author      N. Taylor Mullen <[email protected]>  2025-06-26 16:51:32 +0200
committer   GitHub <[email protected]>             2025-06-26 14:51:32 +0000
commit      24ccc9c4578f40317ee903f731831f42eed699d4 (patch)
tree        aba5073b4371f668fb8cbce91c21cedace377c1e /packages/core/src
parent      121bba346411cce23e350b833dc5857ea2239f2f (diff)
feat: Add model selection logic (#1678)
Diffstat (limited to 'packages/core/src')
-rw-r--r--  packages/core/src/core/geminiChat.test.ts   8
-rw-r--r--  packages/core/src/core/geminiChat.ts        91
2 files changed, 94 insertions(+), 5 deletions(-)
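
The commit makes GeminiChat pick between Flash and Pro per request. The decision, summarized from the diff below: if the configured model is already Flash, keep it; if there are fewer than five turns of history under Gemini API key auth, use Pro (working around an empty-token bug); otherwise ask Flash itself, via a constrained-JSON call, to classify the request. A minimal sketch of that classifier call follows — an outline, not the committed code; the loose client type and the 'gemini-2.5-flash' literal are assumptions, while the schema, sampling config, and fallback behavior are taken from the diff:

// Sketch of the routing classifier pattern added in this commit (see geminiChat.ts below).
import { Content } from '@google/genai';

const FLASH = 'flash';
const PRO = 'pro';

async function routeRequest(
  // Assumed shape: a client exposing a generateJson helper as the diff uses it.
  client: { generateJson: (...args: unknown[]) => Promise<{ model?: string }> },
  history: Content[],
  signal: AbortSignal,
): Promise<'flash' | 'pro'> {
  const routerTurn: Content = {
    role: 'user',
    parts: [{ text: `Reply with JSON: { "model": "${FLASH}" } or { "model": "${PRO}" }.` }],
  };
  try {
    const choice = await client.generateJson(
      [...history, routerTurn],
      {
        type: 'object',
        properties: { model: { type: 'string', enum: [FLASH, PRO] } },
        required: ['model'],
      },
      signal,
      'gemini-2.5-flash', // assumed value of DEFAULT_GEMINI_FLASH_MODEL
      { temperature: 0, maxOutputTokens: 25, thinkingConfig: { thinkingBudget: 0 } },
    );
    return choice.model === PRO ? PRO : FLASH;
  } catch {
    return FLASH; // on any routing failure, fall back to the cheaper model
  }
}

Routing with the cheap model and capping output at 25 tokens keeps the extra round trip inexpensive relative to the main request.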
diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts
index bfaeb8f6..67fa676d 100644
--- a/packages/core/src/core/geminiChat.test.ts
+++ b/packages/core/src/core/geminiChat.test.ts
@@ -14,6 +14,7 @@ import {
} from '@google/genai';
import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
+import { AuthType } from '../core/contentGenerator.js';
import { setSimulate429 } from '../utils/testUtils.js';
// Mocks
@@ -38,11 +39,14 @@ describe('GeminiChat', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
- authType: 'oauth-personal',
+ authType: AuthType.USE_GEMINI,
model: 'test-model',
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
setModel: vi.fn(),
+ getGeminiClient: vi.fn().mockReturnValue({
+ generateJson: vi.fn().mockResolvedValue({ model: 'pro' }),
+ }),
flashFallbackHandler: undefined,
} as unknown as Config;
@@ -110,7 +114,7 @@ describe('GeminiChat', () => {
await chat.sendMessageStream({ message: 'hello' });
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
- model: 'gemini-pro',
+ model: 'gemini-2.5-pro',
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
config: {},
});
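
The updated mock pins getGeminiClient().generateJson to { model: 'pro' }, and with AuthType.USE_GEMINI plus a one-message history the new short-history branch also resolves to Pro, so the stream assertion becomes 'gemini-2.5-pro'. A hypothetical companion test — not part of this commit; the mockConfig, chat, and mockModelsModule names follow the diff's apparent conventions — could pin that short-circuit behavior directly:

// Hypothetical companion test, not in this diff: with AuthType.USE_GEMINI and a
// short history, _selectModel should return Pro without calling the router at all.
it('skips the routing call for short histories on Gemini API key auth', async () => {
  const generateJson = vi.fn().mockResolvedValue({ model: 'flash' });
  (mockConfig.getGeminiClient as ReturnType<typeof vi.fn>).mockReturnValue({
    generateJson,
  });
  await chat.sendMessageStream({ message: 'hello' });
  expect(generateJson).not.toHaveBeenCalled(); // history.length < 5 short-circuits
  expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith(
    expect.objectContaining({ model: 'gemini-2.5-pro' }),
  );
});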
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index 19b87805..770f8bb6 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -34,7 +34,10 @@ import {
ApiRequestEvent,
ApiResponseEvent,
} from '../telemetry/types.js';
-import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
+import {
+ DEFAULT_GEMINI_FLASH_MODEL,
+ DEFAULT_GEMINI_MODEL,
+} from '../config/models.js';
/**
* Returns true if the response is valid, false otherwise.
@@ -346,14 +349,20 @@ export class GeminiChat {
await this.sendPromise;
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
- this._logApiRequest(requestContents, this.config.getModel());
+
+ const model = await this._selectModel(
+ requestContents,
+ params.config?.abortSignal ?? new AbortController().signal,
+ );
+
+ this._logApiRequest(requestContents, model);
const startTime = Date.now();
try {
const apiCall = () =>
this.contentGenerator.generateContentStream({
- model: this.config.getModel(),
+ model,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
});
@@ -398,6 +407,82 @@ export class GeminiChat {
}
/**
+ * Selects the model to use for the request.
+ *
+ * The routing logic here is an initial placeholder and may change.
+ */
+ private async _selectModel(
+ history: Content[],
+ signal: AbortSignal,
+ ): Promise<string> {
+ const currentModel = this.config.getModel();
+ if (currentModel === DEFAULT_GEMINI_FLASH_MODEL) {
+ return DEFAULT_GEMINI_FLASH_MODEL;
+ }
+
+ if (
+ history.length < 5 &&
+ this.config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI
+ ) {
+ // There's currently a bug where, with Gemini API key usage, requesting Flash for one of the
+ // first requests in a sequence returns an empty token, so route to Pro early in the session.
+ return DEFAULT_GEMINI_MODEL;
+ }
+
+ const flashIndicator = 'flash';
+ const proIndicator = 'pro';
+ const modelChoicePrompt = `You are a super-intelligent router that decides which model to use for a given request. You have two models to choose from: "${flashIndicator}" and "${proIndicator}". "${flashIndicator}" is a smaller and faster model that is good for simple or well defined requests. "${proIndicator}" is a larger and slower model that is good for complex or undefined requests.
+
+Based on the user request, which model should be used? Respond with a JSON object that contains a single field, \`model\`, whose value is the name of the model to be used.
+
+For example, if you think "${flashIndicator}" should be used, respond with: { "model": "${flashIndicator}" }`;
+ const modelChoiceContent: Content[] = [
+ {
+ role: 'user',
+ parts: [{ text: modelChoicePrompt }],
+ },
+ ];
+
+ const client = this.config.getGeminiClient();
+ try {
+ const choice = await client.generateJson(
+ [...history, ...modelChoiceContent],
+ {
+ type: 'object',
+ properties: {
+ model: {
+ type: 'string',
+ enum: [flashIndicator, proIndicator],
+ },
+ },
+ required: ['model'],
+ },
+ signal,
+ DEFAULT_GEMINI_FLASH_MODEL,
+ {
+ temperature: 0,
+ maxOutputTokens: 25,
+ thinkingConfig: {
+ thinkingBudget: 0,
+ },
+ },
+ );
+
+ switch (choice.model) {
+ case flashIndicator:
+ return DEFAULT_GEMINI_FLASH_MODEL;
+ case proIndicator:
+ return DEFAULT_GEMINI_MODEL;
+ default:
+ return currentModel;
+ }
+ } catch (_e) {
+ // If the model selection fails, just use the default flash model.
+ return DEFAULT_GEMINI_FLASH_MODEL;
+ }
+ }
+
+ /**
* Returns the chat history.
*
* @remarks