summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/cli/configuration.md11
-rw-r--r--packages/cli/src/config/config.ts1
-rw-r--r--packages/cli/src/config/settings.ts3
-rw-r--r--packages/cli/src/nonInteractiveCli.test.ts47
-rw-r--r--packages/cli/src/nonInteractiveCli.ts12
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts18
-rw-r--r--packages/core/src/config/config.ts7
-rw-r--r--packages/core/src/core/client.test.ts130
-rw-r--r--packages/core/src/core/client.ts9
-rw-r--r--packages/core/src/core/turn.ts8
10 files changed, 231 insertions, 15 deletions
diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md
index 79a2ffc3..d175aa4f 100644
--- a/docs/cli/configuration.md
+++ b/docs/cli/configuration.md
@@ -189,6 +189,14 @@ In addition to a project settings file, a project's `.gemini` directory can cont
"hideTips": true
```
+- **`maxSessionTurns`** (number):
+ - **Description:** Sets the maximum number of turns for a session. If the session exceeds this limit, the CLI will stop processing and start a new chat.
+ - **Default:** `-1` (unlimited)
+ - **Example:**
+ ```json
+ "maxSessionTurns": 10
+ ```
+
### Example `settings.json`:
```json
@@ -213,7 +221,8 @@ In addition to a project settings file, a project's `.gemini` directory can cont
"logPrompts": true
},
"usageStatisticsEnabled": true,
- "hideTips": false
+ "hideTips": false,
+ "maxSessionTurns": 10
}
```
diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts
index b80b6dd0..b685f090 100644
--- a/packages/cli/src/config/config.ts
+++ b/packages/cli/src/config/config.ts
@@ -312,6 +312,7 @@ export async function loadCliConfig(
bugCommand: settings.bugCommand,
model: argv.model!,
extensionContextFilePaths,
+ maxSessionTurns: settings.maxSessionTurns ?? -1,
listExtensions: argv.listExtensions || false,
activeExtensions: activeExtensions.map((e) => ({
name: e.config.name,
diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts
index 133701f5..2abe8cd8 100644
--- a/packages/cli/src/config/settings.ts
+++ b/packages/cli/src/config/settings.ts
@@ -80,6 +80,9 @@ export interface Settings {
hideWindowTitle?: boolean;
hideTips?: boolean;
+ // Setting for setting maximum number of user/model/tool turns in a session.
+ maxSessionTurns?: number;
+
// Add other settings here.
}
diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts
index 14352f53..6cbb630d 100644
--- a/packages/cli/src/nonInteractiveCli.test.ts
+++ b/packages/cli/src/nonInteractiveCli.test.ts
@@ -53,6 +53,7 @@ describe('runNonInteractive', () => {
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
+ getMaxSessionTurns: vi.fn().mockReturnValue(10),
initialize: vi.fn(),
} as unknown as Config;
@@ -294,4 +295,50 @@ describe('runNonInteractive', () => {
'Unfortunately the tool does not exist.',
);
});
+
+ it('should exit when max session turns are exceeded', async () => {
+ const functionCall: FunctionCall = {
+ id: 'fcLoop',
+ name: 'loopTool',
+ args: {},
+ };
+ const toolResponsePart: Part = {
+ functionResponse: {
+ name: 'loopTool',
+ id: 'fcLoop',
+ response: { result: 'still looping' },
+ },
+ };
+
+ // Config with a max turn of 1
+ vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1);
+
+ const { executeToolCall: mockCoreExecuteToolCall } = await import(
+ '@google/gemini-cli-core'
+ );
+ vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
+ callId: 'fcLoop',
+ responseParts: [toolResponsePart],
+ resultDisplay: 'Still looping',
+ error: undefined,
+ });
+
+ const stream = (async function* () {
+ yield { functionCalls: [functionCall] } as GenerateContentResponse;
+ })();
+
+ mockChat.sendMessageStream.mockResolvedValue(stream);
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {});
+
+ await runNonInteractive(mockConfig, 'Trigger loop');
+
+ expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1);
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ `
+ Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.`,
+ );
+ expect(mockProcessExit).not.toHaveBeenCalled();
+ });
});
diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts
index b8b8ac3f..2db28eba 100644
--- a/packages/cli/src/nonInteractiveCli.ts
+++ b/packages/cli/src/nonInteractiveCli.ts
@@ -63,9 +63,19 @@ export async function runNonInteractive(
const chat = await geminiClient.getChat();
const abortController = new AbortController();
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
-
+ let turnCount = 0;
try {
while (true) {
+ turnCount++;
+ if (
+ config.getMaxSessionTurns() > 0 &&
+ turnCount > config.getMaxSessionTurns()
+ ) {
+ console.error(
+ '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
+ );
+ return;
+ }
const functionCalls: FunctionCall[] = [];
const responseStream = await chat.sendMessageStream(
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index b82b0cb2..a9326528 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -431,6 +431,20 @@ export const useGeminiStream = (
[addItem, config],
);
+ const handleMaxSessionTurnsEvent = useCallback(
+ () =>
+ addItem(
+ {
+ type: 'info',
+ text:
+ `The session has reached the maximum number of turns: ${config.getMaxSessionTurns()}. ` +
+ `Please update this limit in your setting.json file.`,
+ },
+ Date.now(),
+ ),
+ [addItem, config],
+ );
+
const processGeminiStreamEvents = useCallback(
async (
stream: AsyncIterable<GeminiEvent>,
@@ -467,6 +481,9 @@ export const useGeminiStream = (
case ServerGeminiEventType.ToolCallResponse:
// do nothing
break;
+ case ServerGeminiEventType.MaxSessionTurns:
+ handleMaxSessionTurnsEvent();
+ break;
default: {
// enforces exhaustive switch-case
const unreachable: never = event;
@@ -485,6 +502,7 @@ export const useGeminiStream = (
handleErrorEvent,
scheduleToolCalls,
handleChatCompressionEvent,
+ handleMaxSessionTurnsEvent,
],
);
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 15e9e73b..12767133 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -139,6 +139,7 @@ export interface ConfigParameters {
bugCommand?: BugCommandSettings;
model: string;
extensionContextFilePaths?: string[];
+ maxSessionTurns?: number;
listExtensions?: boolean;
activeExtensions?: ActiveExtension[];
noBrowser?: boolean;
@@ -182,6 +183,7 @@ export class Config {
private readonly extensionContextFilePaths: string[];
private readonly noBrowser: boolean;
private modelSwitchedDuringSession: boolean = false;
+ private readonly maxSessionTurns: number;
private readonly listExtensions: boolean;
private readonly _activeExtensions: ActiveExtension[];
flashFallbackHandler?: FlashFallbackHandler;
@@ -227,6 +229,7 @@ export class Config {
this.bugCommand = params.bugCommand;
this.model = params.model;
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
+ this.maxSessionTurns = params.maxSessionTurns ?? -1;
this.listExtensions = params.listExtensions ?? false;
this._activeExtensions = params.activeExtensions ?? [];
this.noBrowser = params.noBrowser ?? false;
@@ -308,6 +311,10 @@ export class Config {
this.flashFallbackHandler = handler;
}
+ getMaxSessionTurns(): number {
+ return this.maxSessionTurns;
+ }
+
setQuotaErrorOccurred(value: boolean): void {
this.quotaErrorOccurred = value;
}
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index 2769e1b0..bbcb549b 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -17,7 +17,7 @@ import { findIndexAfterFraction, GeminiClient } from './client.js';
import { AuthType, ContentGenerator } from './contentGenerator.js';
import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
-import { Turn } from './turn.js';
+import { GeminiEventType, Turn } from './turn.js';
import { getCoreSystemPrompt } from './prompts.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
@@ -43,7 +43,13 @@ vi.mock('./turn', () => {
}
}
// Export the mock class as 'Turn'
- return { Turn: MockTurn };
+ return {
+ Turn: MockTurn,
+ GeminiEventType: {
+ MaxSessionTurns: 'MaxSessionTurns',
+ ChatCompressed: 'ChatCompressed',
+ },
+ };
});
vi.mock('../config/config.js');
@@ -68,12 +74,13 @@ vi.mock('../telemetry/index.js', () => ({
describe('findIndexAfterFraction', () => {
const history: Content[] = [
- { role: 'user', parts: [{ text: 'This is the first message.' }] },
- { role: 'model', parts: [{ text: 'This is the second message.' }] },
- { role: 'user', parts: [{ text: 'This is the third message.' }] },
- { role: 'model', parts: [{ text: 'This is the fourth message.' }] },
- { role: 'user', parts: [{ text: 'This is the fifth message.' }] },
+ { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66
+ { role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68
+ { role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66
+ { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68
+ { role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65
];
+ // Total length: 333
it('should throw an error for non-positive numbers', () => {
expect(() => findIndexAfterFraction(history, 0)).toThrow(
@@ -88,14 +95,23 @@ describe('findIndexAfterFraction', () => {
});
it('should handle a fraction in the middle', () => {
- // Total length is 257. 257 * 0.5 = 128.5
- // 0: 53
- // 1: 53 + 54 = 107
- // 2: 107 + 53 = 160
- // 160 >= 128.5, so index is 2
+ // 333 * 0.5 = 166.5
+ // 0: 66
+ // 1: 66 + 68 = 134
+ // 2: 134 + 66 = 200
+ // 200 >= 166.5, so index is 2
expect(findIndexAfterFraction(history, 0.5)).toBe(2);
});
+ it('should handle a fraction that results in the last index', () => {
+ // 333 * 0.9 = 299.7
+ // ...
+ // 3: 200 + 68 = 268
+ // 4: 268 + 65 = 333
+ // 333 >= 299.7, so index is 4
+ expect(findIndexAfterFraction(history, 0.9)).toBe(4);
+ });
+
it('should handle an empty history', () => {
expect(findIndexAfterFraction([], 0.5)).toBe(0);
});
@@ -178,6 +194,7 @@ describe('Gemini Client (client.ts)', () => {
getProxy: vi.fn().mockReturnValue(undefined),
getWorkingDir: vi.fn().mockReturnValue('/test/dir'),
getFileService: vi.fn().mockReturnValue(fileService),
+ getMaxSessionTurns: vi.fn().mockReturnValue(0),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
getNoBrowser: vi.fn().mockReturnValue(false),
@@ -366,6 +383,42 @@ describe('Gemini Client (client.ts)', () => {
contents,
});
});
+
+ it('should allow overriding model and config', async () => {
+ const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
+ const schema = { type: 'string' };
+ const abortSignal = new AbortController().signal;
+ const customModel = 'custom-json-model';
+ const customConfig = { temperature: 0.9, topK: 20 };
+
+ const mockGenerator: Partial<ContentGenerator> = {
+ countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
+ generateContent: mockGenerateContentFn,
+ };
+ client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+ await client.generateJson(
+ contents,
+ schema,
+ abortSignal,
+ customModel,
+ customConfig,
+ );
+
+ expect(mockGenerateContentFn).toHaveBeenCalledWith({
+ model: customModel,
+ config: {
+ abortSignal,
+ systemInstruction: getCoreSystemPrompt(''),
+ temperature: 0.9,
+ topP: 1, // from default
+ topK: 20,
+ responseSchema: schema,
+ responseMimeType: 'application/json',
+ },
+ contents,
+ });
+ });
});
describe('addHistory', () => {
@@ -660,6 +713,59 @@ describe('Gemini Client (client.ts)', () => {
expect(eventCount).toBeLessThan(200); // Should not exceed our safety limit
});
+ it('should yield MaxSessionTurns and stop when session turn limit is reached', async () => {
+ // Arrange
+ const MAX_SESSION_TURNS = 5;
+ vi.spyOn(client['config'], 'getMaxSessionTurns').mockReturnValue(
+ MAX_SESSION_TURNS,
+ );
+
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Hello' };
+ })();
+ mockTurnRunFn.mockReturnValue(mockStream);
+
+ const mockChat: Partial<GeminiChat> = {
+ addHistory: vi.fn(),
+ getHistory: vi.fn().mockReturnValue([]),
+ };
+ client['chat'] = mockChat as GeminiChat;
+
+ const mockGenerator: Partial<ContentGenerator> = {
+ countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
+ };
+ client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+ // Act & Assert
+ // Run up to the limit
+ for (let i = 0; i < MAX_SESSION_TURNS; i++) {
+ const stream = client.sendMessageStream(
+ [{ text: 'Hi' }],
+ new AbortController().signal,
+ 'prompt-id-4',
+ );
+ // consume stream
+ for await (const _event of stream) {
+ // do nothing
+ }
+ }
+
+ // This call should exceed the limit
+ const stream = client.sendMessageStream(
+ [{ text: 'Hi' }],
+ new AbortController().signal,
+ 'prompt-id-5',
+ );
+
+ const events = [];
+ for await (const event of stream) {
+ events.push(event);
+ }
+
+ expect(events).toEqual([{ type: GeminiEventType.MaxSessionTurns }]);
+ expect(mockTurnRunFn).toHaveBeenCalledTimes(MAX_SESSION_TURNS);
+ });
+
it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => {
// This test verifies that the infinite loop protection works even when
// someone tries to bypass it by calling with a very large turns value
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 5d9ac0cb..0ff8026b 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -86,6 +86,7 @@ export class GeminiClient {
temperature: 0,
topP: 1,
};
+ private sessionTurnCount = 0;
private readonly MAX_TURNS = 100;
/**
* Threshold for compression token count as a fraction of the model's token limit.
@@ -266,6 +267,14 @@ export class GeminiClient {
turns: number = this.MAX_TURNS,
originalModel?: string,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
+ this.sessionTurnCount++;
+ if (
+ this.config.getMaxSessionTurns() > 0 &&
+ this.sessionTurnCount > this.config.getMaxSessionTurns()
+ ) {
+ yield { type: GeminiEventType.MaxSessionTurns };
+ return new Turn(this.getChat(), prompt_id);
+ }
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
const boundedTurns = Math.min(turns, this.MAX_TURNS);
if (!boundedTurns) {
diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts
index aeeaa889..6135b1f6 100644
--- a/packages/core/src/core/turn.ts
+++ b/packages/core/src/core/turn.ts
@@ -48,6 +48,7 @@ export enum GeminiEventType {
Error = 'error',
ChatCompressed = 'chat_compressed',
Thought = 'thought',
+ MaxSessionTurns = 'max_session_turns',
}
export interface StructuredError {
@@ -128,6 +129,10 @@ export type ServerGeminiChatCompressedEvent = {
value: ChatCompressionInfo | null;
};
+export type ServerGeminiMaxSessionTurnsEvent = {
+ type: GeminiEventType.MaxSessionTurns;
+};
+
// The original union type, now composed of the individual types
export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent
@@ -137,7 +142,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent
| ServerGeminiChatCompressedEvent
- | ServerGeminiThoughtEvent;
+ | ServerGeminiThoughtEvent
+ | ServerGeminiMaxSessionTurnsEvent;
// A turn manages the agentic loop turn within the server context.
export class Turn {