summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/core/src/utils/nextSpeakerChecker.test.ts19
-rw-r--r--packages/core/src/utils/nextSpeakerChecker.ts2
2 files changed, 21 insertions, 0 deletions
diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts
index 475b5662..9141105f 100644
--- a/packages/core/src/utils/nextSpeakerChecker.test.ts
+++ b/packages/core/src/utils/nextSpeakerChecker.test.ts
@@ -6,6 +6,7 @@
import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
import { Content, GoogleGenAI, Models } from '@google/genai';
+import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { GeminiClient } from '../core/client.js';
import { Config } from '../config/config.js';
import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
@@ -231,4 +232,22 @@ describe('checkNextSpeaker', () => {
);
expect(result).toBeNull();
});
+
+ it('should call generateJson with DEFAULT_GEMINI_FLASH_MODEL', async () => {
+ (chatInstance.getHistory as Mock).mockReturnValue([
+ { role: 'model', parts: [{ text: 'Some model output.' }] },
+ ] as Content[]);
+ const mockApiResponse: NextSpeakerResponse = {
+ reasoning: 'Model made a statement, awaiting user input.',
+ next_speaker: 'user',
+ };
+ (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
+
+ await checkNextSpeaker(chatInstance, mockGeminiClient, abortSignal);
+
+ expect(mockGeminiClient.generateJson).toHaveBeenCalled();
+ const generateJsonCall = (mockGeminiClient.generateJson as Mock).mock
+ .calls[0];
+ expect(generateJsonCall[3]).toBe(DEFAULT_GEMINI_FLASH_MODEL);
+ });
});
diff --git a/packages/core/src/utils/nextSpeakerChecker.ts b/packages/core/src/utils/nextSpeakerChecker.ts
index 165f277a..9d428887 100644
--- a/packages/core/src/utils/nextSpeakerChecker.ts
+++ b/packages/core/src/utils/nextSpeakerChecker.ts
@@ -5,6 +5,7 @@
*/
import { Content, SchemaUnion, Type } from '@google/genai';
+import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { GeminiClient } from '../core/client.js';
import { GeminiChat } from '../core/geminiChat.js';
import { isFunctionResponse } from './messageInspectors.js';
@@ -131,6 +132,7 @@ export async function checkNextSpeaker(
contents,
RESPONSE_SCHEMA,
abortSignal,
+ DEFAULT_GEMINI_FLASH_MODEL,
)) as unknown as NextSpeakerResponse;
if (