summaryrefslogtreecommitdiff
path: root/packages/server/src/utils
diff options
context:
space:
mode:
Diffstat (limited to 'packages/server/src/utils')
-rw-r--r--packages/server/src/utils/editCorrector.test.ts33
-rw-r--r--packages/server/src/utils/editCorrector.ts37
-rw-r--r--packages/server/src/utils/nextSpeakerChecker.test.ts57
-rw-r--r--packages/server/src/utils/nextSpeakerChecker.ts2
4 files changed, 112 insertions, 17 deletions
diff --git a/packages/server/src/utils/editCorrector.test.ts b/packages/server/src/utils/editCorrector.test.ts
index 27c9ffe8..7d6f5a53 100644
--- a/packages/server/src/utils/editCorrector.test.ts
+++ b/packages/server/src/utils/editCorrector.test.ts
@@ -132,6 +132,7 @@ describe('editCorrector', () => {
let mockGeminiClientInstance: Mocked<GeminiClient>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockConfigInstance: Config;
+ const abortSignal = new AbortController().signal;
beforeEach(() => {
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
@@ -187,12 +188,18 @@ describe('editCorrector', () => {
callCount = 0;
mockResponses.length = 0;
- mockGenerateJson = vi.fn().mockImplementation(() => {
- const response = mockResponses[callCount];
- callCount++;
- if (response === undefined) return Promise.resolve({});
- return Promise.resolve(response);
- });
+ mockGenerateJson = vi
+ .fn()
+ .mockImplementation((_contents, _schema, signal) => {
+ // Check if the signal is aborted. If so, throw an error or return a specific response.
+ if (signal && signal.aborted) {
+ return Promise.reject(new Error('Aborted')); // Or some other specific error/response
+ }
+ const response = mockResponses[callCount];
+ callCount++;
+ if (response === undefined) return Promise.resolve({});
+ return Promise.resolve(response);
+ });
mockStartChat = vi.fn();
mockSendMessageStream = vi.fn();
@@ -217,6 +224,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
@@ -234,6 +242,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this');
@@ -254,6 +263,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
@@ -271,6 +281,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this');
@@ -292,6 +303,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
@@ -309,6 +321,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this');
@@ -329,6 +342,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with foobar');
@@ -351,6 +365,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe(llmNewString);
@@ -372,6 +387,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
expect(result.params.new_string).toBe(llmNewString);
@@ -391,6 +407,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
@@ -412,6 +429,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
@@ -432,6 +450,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params).toEqual(originalParams);
@@ -449,6 +468,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params).toEqual(originalParams);
@@ -471,6 +491,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
+ abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
expect(result.params.old_string).toBe(currentContent);
diff --git a/packages/server/src/utils/editCorrector.ts b/packages/server/src/utils/editCorrector.ts
index 92551478..78663954 100644
--- a/packages/server/src/utils/editCorrector.ts
+++ b/packages/server/src/utils/editCorrector.ts
@@ -63,6 +63,7 @@ export async function ensureCorrectEdit(
currentContent: string,
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
client: GeminiClient,
+ abortSignal: AbortSignal,
): Promise<CorrectedEditResult> {
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
const cachedResult = editCorrectionCache.get(cacheKey);
@@ -84,6 +85,7 @@ export async function ensureCorrectEdit(
client,
finalOldString,
originalParams.new_string,
+ abortSignal,
);
}
} else if (occurrences > 1) {
@@ -108,6 +110,7 @@ export async function ensureCorrectEdit(
originalParams.old_string, // original old
unescapedOldStringAttempt, // corrected old
originalParams.new_string, // original new (which is potentially escaped)
+ abortSignal,
);
}
} else if (occurrences === 0) {
@@ -115,6 +118,7 @@ export async function ensureCorrectEdit(
client,
currentContent,
unescapedOldStringAttempt,
+ abortSignal,
);
const llmOldOccurrences = countOccurrences(
currentContent,
@@ -134,6 +138,7 @@ export async function ensureCorrectEdit(
originalParams.old_string, // original old
llmCorrectedOldString, // corrected old
baseNewStringForLLMCorrection, // base new for correction
+ abortSignal,
);
}
} else {
@@ -180,6 +185,7 @@ export async function ensureCorrectEdit(
export async function ensureCorrectFileContent(
content: string,
client: GeminiClient,
+ abortSignal: AbortSignal,
): Promise<string> {
const cachedResult = fileContentCorrectionCache.get(content);
if (cachedResult) {
@@ -193,7 +199,11 @@ export async function ensureCorrectFileContent(
return content;
}
- const correctedContent = await correctStringEscaping(content, client);
+ const correctedContent = await correctStringEscaping(
+ content,
+ client,
+ abortSignal,
+ );
fileContentCorrectionCache.set(content, correctedContent);
return correctedContent;
}
@@ -215,6 +225,7 @@ export async function correctOldStringMismatch(
geminiClient: GeminiClient,
fileContent: string,
problematicSnippet: string,
+ abortSignal: AbortSignal,
): Promise<string> {
const prompt = `
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.
@@ -243,6 +254,7 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
const result = await geminiClient.generateJson(
contents,
OLD_STRING_CORRECTION_SCHEMA,
+ abortSignal,
EditModel,
EditConfig,
);
@@ -257,10 +269,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
return problematicSnippet;
}
} catch (error) {
+ if (abortSignal.aborted) {
+ throw error;
+ }
+
console.error(
'Error during LLM call for old string snippet correction:',
error,
);
+
return problematicSnippet;
}
}
@@ -286,6 +303,7 @@ export async function correctNewString(
originalOldString: string,
correctedOldString: string,
originalNewString: string,
+ abortSignal: AbortSignal,
): Promise<string> {
if (originalOldString === correctedOldString) {
return originalNewString;
@@ -324,6 +342,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const result = await geminiClient.generateJson(
contents,
NEW_STRING_CORRECTION_SCHEMA,
+ abortSignal,
EditModel,
EditConfig,
);
@@ -338,6 +357,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
return originalNewString;
}
} catch (error) {
+ if (abortSignal.aborted) {
+ throw error;
+ }
+
console.error('Error during LLM call for new_string correction:', error);
return originalNewString;
}
@@ -359,6 +382,7 @@ export async function correctNewStringEscaping(
geminiClient: GeminiClient,
oldString: string,
potentiallyProblematicNewString: string,
+ abortSignal: AbortSignal,
): Promise<string> {
const prompt = `
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
@@ -387,6 +411,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const result = await geminiClient.generateJson(
contents,
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
+ abortSignal,
EditModel,
EditConfig,
);
@@ -401,6 +426,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
return potentiallyProblematicNewString;
}
} catch (error) {
+ if (abortSignal.aborted) {
+ throw error;
+ }
+
console.error(
'Error during LLM call for new_string escaping correction:',
error,
@@ -424,6 +453,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
export async function correctStringEscaping(
potentiallyProblematicString: string,
client: GeminiClient,
+ abortSignal: AbortSignal,
): Promise<string> {
const prompt = `
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
@@ -447,6 +477,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const result = await client.generateJson(
contents,
CORRECT_STRING_ESCAPING_SCHEMA,
+ abortSignal,
EditModel,
EditConfig,
);
@@ -461,6 +492,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
return potentiallyProblematicString;
}
} catch (error) {
+ if (abortSignal.aborted) {
+ throw error;
+ }
+
console.error(
'Error during LLM call for string escaping correction:',
error,
diff --git a/packages/server/src/utils/nextSpeakerChecker.test.ts b/packages/server/src/utils/nextSpeakerChecker.test.ts
index 1d87bffb..872e00f6 100644
--- a/packages/server/src/utils/nextSpeakerChecker.test.ts
+++ b/packages/server/src/utils/nextSpeakerChecker.test.ts
@@ -44,6 +44,7 @@ describe('checkNextSpeaker', () => {
let chatInstance: GeminiChat;
let mockGeminiClient: GeminiClient;
let MockConfig: Mock;
+ const abortSignal = new AbortController().signal;
beforeEach(() => {
MockConfig = vi.mocked(Config);
@@ -71,7 +72,7 @@ describe('checkNextSpeaker', () => {
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
'gemini-pro', // model name
- {}, // config
+ {},
[], // initial history
);
@@ -85,7 +86,11 @@ describe('checkNextSpeaker', () => {
it('should return null if history is empty', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([]);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toBeNull();
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
});
@@ -94,7 +99,11 @@ describe('checkNextSpeaker', () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'user', parts: [{ text: 'Hello' }] },
] as Content[]);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toBeNull();
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
});
@@ -109,7 +118,11 @@ describe('checkNextSpeaker', () => {
};
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toEqual(mockApiResponse);
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
});
@@ -124,7 +137,11 @@ describe('checkNextSpeaker', () => {
};
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toEqual(mockApiResponse);
});
@@ -138,7 +155,11 @@ describe('checkNextSpeaker', () => {
};
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toEqual(mockApiResponse);
});
@@ -153,7 +174,11 @@ describe('checkNextSpeaker', () => {
new Error('API Error'),
);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toBeNull();
consoleWarnSpy.mockRestore();
});
@@ -166,7 +191,11 @@ describe('checkNextSpeaker', () => {
reasoning: 'This is incomplete.',
} as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toBeNull();
});
@@ -179,7 +208,11 @@ describe('checkNextSpeaker', () => {
next_speaker: 123, // Invalid type
} as unknown as NextSpeakerResponse);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toBeNull();
});
@@ -192,7 +225,11 @@ describe('checkNextSpeaker', () => {
next_speaker: 'neither', // Invalid enum value
} as unknown as NextSpeakerResponse);
- const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
+ const result = await checkNextSpeaker(
+ chatInstance,
+ mockGeminiClient,
+ abortSignal,
+ );
expect(result).toBeNull();
});
});
diff --git a/packages/server/src/utils/nextSpeakerChecker.ts b/packages/server/src/utils/nextSpeakerChecker.ts
index fb00b39c..66fa4395 100644
--- a/packages/server/src/utils/nextSpeakerChecker.ts
+++ b/packages/server/src/utils/nextSpeakerChecker.ts
@@ -61,6 +61,7 @@ export interface NextSpeakerResponse {
export async function checkNextSpeaker(
chat: GeminiChat,
geminiClient: GeminiClient,
+ abortSignal: AbortSignal,
): Promise<NextSpeakerResponse | null> {
// We need to capture the curated history because there are many moments when the model will return invalid turns
// that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides
@@ -129,6 +130,7 @@ export async function checkNextSpeaker(
const parsedResponse = (await geminiClient.generateJson(
contents,
RESPONSE_SCHEMA,
+ abortSignal,
)) as unknown as NextSpeakerResponse;
if (