summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/cli/src/nonInteractiveCli.test.ts3
-rw-r--r--packages/cli/src/nonInteractiveCli.ts2
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.test.tsx10
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts37
-rw-r--r--packages/core/src/config/config.test.ts1
-rw-r--r--packages/core/src/core/client.ts15
-rw-r--r--packages/core/src/tools/tool-registry.ts4
7 files changed, 26 insertions, 46 deletions
diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts
index 389d35f2..c071dedc 100644
--- a/packages/cli/src/nonInteractiveCli.test.ts
+++ b/packages/cli/src/nonInteractiveCli.test.ts
@@ -39,7 +39,7 @@ describe('runNonInteractive', () => {
sendMessageStream: vi.fn(),
};
mockGeminiClient = {
- startChat: vi.fn().mockResolvedValue(mockChat),
+ getChat: vi.fn().mockResolvedValue(mockChat),
} as unknown as GeminiClient;
mockToolRegistry = {
getFunctionDeclarations: vi.fn().mockReturnValue([]),
@@ -80,7 +80,6 @@ describe('runNonInteractive', () => {
await runNonInteractive(mockConfig, 'Test input');
- expect(mockGeminiClient.startChat).toHaveBeenCalled();
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
message: [{ text: 'Test input' }],
config: {
diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts
index f7b4108b..7505c736 100644
--- a/packages/cli/src/nonInteractiveCli.ts
+++ b/packages/cli/src/nonInteractiveCli.ts
@@ -42,7 +42,7 @@ export async function runNonInteractive(
const geminiClient = new GeminiClient(config);
const toolRegistry: ToolRegistry = await config.getToolRegistry();
- const chat = await geminiClient.startChat();
+ const chat = await geminiClient.getChat();
const abortController = new AbortController();
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index 3a421ebf..d46fab9e 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -405,10 +405,9 @@ describe('useGeminiStream', () => {
} as TrackedCancelledToolCall,
];
- let hookResult: any;
- await act(async () => {
- hookResult = renderTestHook(simplifiedToolCalls);
- });
+ const hookResult = await act(async () =>
+ renderTestHook(simplifiedToolCalls),
+ );
const {
mockMarkToolsAsSubmitted,
@@ -431,9 +430,8 @@ describe('useGeminiStream', () => {
toolCall2ResponseParts,
]);
expect(localMockSendMessageStream).toHaveBeenCalledWith(
- expect.anything(),
expectedMergedResponse,
- expect.anything(),
+ expect.any(AbortSignal),
);
});
});
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 423f3489..284709cf 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -17,7 +17,6 @@ import {
Config,
MessageSenderType,
ToolCallRequestInfo,
- GeminiChat,
} from '@gemini-code/core';
import { type PartListUnion } from '@google/genai';
import {
@@ -76,7 +75,6 @@ export const useGeminiStream = (
) => {
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
- const chatSessionRef = useRef<GeminiChat | null>(null);
const geminiClientRef = useRef<GeminiClient | null>(null);
const [isResponding, setIsResponding] = useState<boolean>(false);
const [pendingHistoryItemRef, setPendingHistoryItem] =
@@ -256,31 +254,6 @@ export const useGeminiStream = (
],
);
- const ensureChatSession = useCallback(async (): Promise<{
- client: GeminiClient | null;
- chat: GeminiChat | null;
- }> => {
- const currentClient = geminiClientRef.current;
- if (!currentClient) {
- const errorMsg = 'Gemini client is not available.';
- setInitError(errorMsg);
- addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
- return { client: null, chat: null };
- }
-
- if (!chatSessionRef.current) {
- try {
- chatSessionRef.current = await currentClient.startChat();
- } catch (err: unknown) {
- const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`;
- setInitError(errorMsg);
- addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
- return { client: currentClient, chat: null };
- }
- }
- return { client: currentClient, chat: chatSessionRef.current };
- }, [addItem]);
-
// --- Stream Event Handlers ---
const handleContentEvent = useCallback(
@@ -444,9 +417,12 @@ export const useGeminiStream = (
return;
}
- const { client, chat } = await ensureChatSession();
+ const client = geminiClientRef.current;
- if (!client || !chat) {
+ if (!client) {
+ const errorMsg = 'Gemini client is not available.';
+ setInitError(errorMsg);
+ addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
return;
}
@@ -454,7 +430,7 @@ export const useGeminiStream = (
setInitError(null);
try {
- const stream = client.sendMessageStream(chat, queryToSend, abortSignal);
+ const stream = client.sendMessageStream(queryToSend, abortSignal);
const processingStatus = await processGeminiStreamEvents(
stream,
userMessageTimestamp,
@@ -487,7 +463,6 @@ export const useGeminiStream = (
streamingState,
setShowHelp,
prepareQueryForGemini,
- ensureChatSession,
processGeminiStreamEvents,
pendingHistoryItemRef,
addItem,
diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts
index c3c46659..9b4b2664 100644
--- a/packages/core/src/config/config.test.ts
+++ b/packages/core/src/config/config.test.ts
@@ -35,6 +35,7 @@ vi.mock('../tools/memoryTool', () => ({
setGeminiMdFilename: vi.fn(),
getCurrentGeminiMdFilename: vi.fn(() => 'GEMINI.md'), // Mock the original filename
DEFAULT_CONTEXT_FILENAME: 'GEMINI.md',
+ GEMINI_CONFIG_DIR: '.gemini',
}));
describe('Server Config (config.ts)', () => {
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 732126cb..fcad1ef0 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -27,6 +27,7 @@ import { GeminiChat } from './geminiChat.js';
import { retryWithBackoff } from '../utils/retry.js';
export class GeminiClient {
+ private chat: Promise<GeminiChat>;
private client: GoogleGenAI;
private model: string;
private generateContentConfig: GenerateContentConfig = {
@@ -50,6 +51,11 @@ export class GeminiClient {
},
});
this.model = config.getModel();
+ this.chat = this.startChat();
+ }
+
+ getChat(): Promise<GeminiChat> {
+ return this.chat;
}
private async getEnvironment(): Promise<Part[]> {
@@ -114,12 +120,12 @@ export class GeminiClient {
return initialParts;
}
- async startChat(): Promise<GeminiChat> {
+ private async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
const envParts = await this.getEnvironment();
const toolRegistry = await this.config.getToolRegistry();
const toolDeclarations = toolRegistry.getFunctionDeclarations();
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
- const history: Content[] = [
+ const initialHistory: Content[] = [
{
role: 'user',
parts: envParts,
@@ -129,6 +135,7 @@ export class GeminiClient {
parts: [{ text: 'Got it. Thanks for the context!' }],
},
];
+ const history = initialHistory.concat(extraHistory ?? []);
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
@@ -157,7 +164,6 @@ export class GeminiClient {
}
async *sendMessageStream(
- chat: GeminiChat,
request: PartListUnion,
signal: AbortSignal,
turns: number = this.MAX_TURNS,
@@ -166,6 +172,7 @@ export class GeminiClient {
return;
}
+ const chat = await this.chat;
const turn = new Turn(chat);
const resultStream = turn.run(request, signal);
for await (const event of resultStream) {
@@ -175,7 +182,7 @@ export class GeminiClient {
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
- yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
+ yield* this.sendMessageStream(nextRequest, signal, turns - 1);
}
}
}
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts
index 21aec687..12aa1a83 100644
--- a/packages/core/src/tools/tool-registry.ts
+++ b/packages/core/src/tools/tool-registry.ts
@@ -56,10 +56,10 @@ Signal: Signal number or \`(none)\` if no signal was received.
let stdout = '';
let stderr = '';
child.stdout.on('data', (data) => {
- stdout += data.toString();
+ stdout += data?.toString();
});
child.stderr.on('data', (data) => {
- stderr += data.toString();
+ stderr += data?.toString();
});
let error: Error | null = null;
child.on('error', (err: Error) => {