diff options
Diffstat (limited to 'packages/cli/src/ui/hooks')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 314 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 5 |
2 files changed, 312 insertions, 7 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 44013059..3a421ebf 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -4,19 +4,102 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi } from 'vitest'; -import { mergePartListUnions } from './useGeminiStream.js'; +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, it, expect, vi, beforeEach, Mock } from 'vitest'; +import { renderHook, act, waitFor } from '@testing-library/react'; +import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js'; +import { + useReactToolScheduler, + TrackedToolCall, + TrackedCompletedToolCall, + TrackedExecutingToolCall, + TrackedCancelledToolCall, +} from './useReactToolScheduler.js'; +import { Config } from '@gemini-code/core'; import { Part, PartListUnion } from '@google/genai'; +import { UseHistoryManagerReturn } from './useHistoryManager.js'; -// Mock useToolScheduler -vi.mock('./useReactToolScheduler', async () => { - const actual = await vi.importActual('./useReactToolScheduler'); +// --- MOCKS --- +const mockSendMessageStream = vi + .fn() + .mockReturnValue((async function* () {})()); +const mockStartChat = vi.fn(); + +vi.mock('@gemini-code/core', async (importOriginal) => { + const actualCoreModule = (await importOriginal()) as any; + const MockedGeminiClientClass = vi.fn().mockImplementation(function ( + this: any, + _config: any, + ) { + // _config + this.startChat = mockStartChat; + this.sendMessageStream = mockSendMessageStream; + }); + return { + ...(actualCoreModule || {}), + GeminiClient: MockedGeminiClientClass, + // GeminiChat will be from actualCoreModule if it exists, otherwise undefined + }; +}); + +const mockUseReactToolScheduler = useReactToolScheduler as Mock; +vi.mock('./useReactToolScheduler.js', async (importOriginal) => { + const actualSchedulerModule = (await importOriginal()) as any; return { - ...actual, // We need mapToDisplay from actual + ...(actualSchedulerModule || {}), useReactToolScheduler: vi.fn(), }; }); +vi.mock('ink', async (importOriginal) => { + const actualInkModule = (await importOriginal()) as any; + return { ...(actualInkModule || {}), useInput: vi.fn() }; +}); + +vi.mock('./shellCommandProcessor.js', () => ({ + useShellCommandProcessor: vi.fn().mockReturnValue({ + handleShellCommand: vi.fn(), + }), +})); + +vi.mock('./atCommandProcessor.js', () => ({ + handleAtCommand: vi + .fn() + .mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }), +})); + +vi.mock('../utils/markdownUtilities.js', () => ({ + findLastSafeSplitPoint: vi.fn((s: string) => s.length), +})); + +vi.mock('./useStateAndRef.js', () => ({ + useStateAndRef: vi.fn((initial) => { + let val = initial; + const ref = { current: val }; + const setVal = vi.fn((updater) => { + if (typeof updater === 'function') { + val = updater(val); + } else { + val = updater; + } + ref.current = val; + }); + return [ref, setVal]; + }), +})); + +vi.mock('./useLogger.js', () => ({ + useLogger: vi.fn().mockReturnValue({ + logMessage: vi.fn().mockResolvedValue(undefined), + }), +})); + +vi.mock('./slashCommandProcessor.js', () => ({ + handleSlashCommand: vi.fn().mockReturnValue(false), +})); + +// --- END MOCKS --- + describe('mergePartListUnions', () => { it('should merge multiple PartListUnion arrays', () => { const list1: PartListUnion = [{ text: 'Hello' }]; @@ -135,3 +218,222 @@ describe('mergePartListUnions', () => { ]); }); }); + +// --- Tests for useGeminiStream Hook --- +describe('useGeminiStream', () => { + let mockAddItem: Mock; + let mockSetShowHelp: Mock; + let mockConfig: Config; + let mockOnDebugMessage: Mock; + let mockHandleSlashCommand: Mock; + let mockScheduleToolCalls: Mock; + let mockCancelAllToolCalls: Mock; + let mockMarkToolsAsSubmitted: Mock; + + beforeEach(() => { + vi.clearAllMocks(); // Clear mocks before each test + + mockAddItem = vi.fn(); + mockSetShowHelp = vi.fn(); + mockConfig = { + apiKey: 'test-api-key', + model: 'gemini-pro', + sandbox: false, + targetDir: '/test/dir', + debugMode: false, + question: undefined, + fullContext: false, + coreTools: [], + toolDiscoveryCommand: undefined, + toolCallCommand: undefined, + mcpServerCommand: undefined, + mcpServers: undefined, + userAgent: 'test-agent', + userMemory: '', + geminiMdFileCount: 0, + alwaysSkipModificationConfirmation: false, + vertexai: false, + showMemoryUsage: false, + contextFileName: undefined, + getToolRegistry: vi.fn( + () => ({ getToolSchemaList: vi.fn(() => []) }) as any, + ), + } as unknown as Config; + mockOnDebugMessage = vi.fn(); + mockHandleSlashCommand = vi.fn().mockReturnValue(false); + + // Mock return value for useReactToolScheduler + mockScheduleToolCalls = vi.fn(); + mockCancelAllToolCalls = vi.fn(); + mockMarkToolsAsSubmitted = vi.fn(); + + // Default mock for useReactToolScheduler to prevent toolCalls being undefined initially + mockUseReactToolScheduler.mockReturnValue([ + [], // Default to empty array for toolCalls + mockScheduleToolCalls, + mockCancelAllToolCalls, + mockMarkToolsAsSubmitted, + ]); + + // Reset mocks for GeminiClient instance methods (startChat and sendMessageStream) + // The GeminiClient constructor itself is mocked at the module level. + mockStartChat.mockClear().mockResolvedValue({ + sendMessageStream: mockSendMessageStream, + } as unknown as any); // GeminiChat -> any + mockSendMessageStream + .mockClear() + .mockReturnValue((async function* () {})()); + }); + + const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => { + mockUseReactToolScheduler.mockReturnValue([ + initialToolCalls, + mockScheduleToolCalls, + mockCancelAllToolCalls, + mockMarkToolsAsSubmitted, + ]); + + const { result, rerender } = renderHook(() => + useGeminiStream( + mockAddItem as unknown as UseHistoryManagerReturn['addItem'], + mockSetShowHelp, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, // shellModeActive + ), + ); + return { + result, + rerender, + mockMarkToolsAsSubmitted, + mockSendMessageStream, + // mockFilter removed + }; + }; + + it('should not submit tool responses if not all tool calls are completed', () => { + const toolCalls: TrackedToolCall[] = [ + { + request: { callId: 'call1', name: 'tool1', args: {} }, + status: 'success', + responseSubmittedToGemini: false, + response: { + callId: 'call1', + responseParts: [{ text: 'tool 1 response' }], + error: undefined, + resultDisplay: 'Tool 1 success display', + }, + tool: { + name: 'tool1', + description: 'desc1', + getDescription: vi.fn(), + } as any, + startTime: Date.now(), + endTime: Date.now(), + } as TrackedCompletedToolCall, + { + request: { callId: 'call2', name: 'tool2', args: {} }, + status: 'executing', + responseSubmittedToGemini: false, + tool: { + name: 'tool2', + description: 'desc2', + getDescription: vi.fn(), + } as any, + startTime: Date.now(), + liveOutput: '...', + } as TrackedExecutingToolCall, + ]; + + const { mockMarkToolsAsSubmitted, mockSendMessageStream } = + renderTestHook(toolCalls); + + // Effect for submitting tool responses depends on toolCalls and isResponding + // isResponding is initially false, so the effect should run. + + expect(mockMarkToolsAsSubmitted).not.toHaveBeenCalled(); + expect(mockSendMessageStream).not.toHaveBeenCalled(); // submitQuery uses this + }); + + it('should submit tool responses when all tool calls are completed and ready', async () => { + const toolCall1ResponseParts: PartListUnion = [ + { text: 'tool 1 final response' }, + ]; + const toolCall2ResponseParts: PartListUnion = [ + { text: 'tool 2 final response' }, + ]; + + // Simplified toolCalls to ensure the filter logic is the focus + const simplifiedToolCalls: TrackedToolCall[] = [ + { + request: { callId: 'call1', name: 'tool1', args: {} }, + status: 'success', + responseSubmittedToGemini: false, + response: { + callId: 'call1', + responseParts: toolCall1ResponseParts, + error: undefined, + resultDisplay: 'Tool 1 success display', + }, + tool: { + name: 'tool1', + description: 'desc', + getDescription: vi.fn(), + } as any, + startTime: Date.now(), + endTime: Date.now(), + } as TrackedCompletedToolCall, + { + request: { callId: 'call2', name: 'tool2', args: {} }, + status: 'cancelled', + responseSubmittedToGemini: false, + response: { + callId: 'call2', + responseParts: toolCall2ResponseParts, + error: undefined, + resultDisplay: 'Tool 2 cancelled display', + }, + tool: { + name: 'tool2', + description: 'desc', + getDescription: vi.fn(), + } as any, + startTime: Date.now(), + endTime: Date.now(), + reason: 'test cancellation', + } as TrackedCancelledToolCall, + ]; + + let hookResult: any; + await act(async () => { + hookResult = renderTestHook(simplifiedToolCalls); + }); + + const { + mockMarkToolsAsSubmitted, + mockSendMessageStream: localMockSendMessageStream, + } = hookResult!; + + // It seems the initial render + effect run should be enough. + // If rerender was for a specific state change, it might still be needed. + // For now, let's test if the initial effect run (covered by the first act) is sufficient. + // If not, we can add back: await act(async () => { rerender({}); }); + + expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']); + + await waitFor(() => { + expect(localMockSendMessageStream).toHaveBeenCalledTimes(1); + }); + + const expectedMergedResponse = mergePartListUnions([ + toolCall1ResponseParts, + toolCall2ResponseParts, + ]); + expect(localMockSendMessageStream).toHaveBeenCalledWith( + expect.anything(), + expectedMergedResponse, + expect.anything(), + ); + }); +}); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 35e5a26a..b6ef1481 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -530,7 +530,10 @@ export const useGeminiStream = ( }, ); - if (completedAndReadyToSubmitTools.length > 0) { + if ( + completedAndReadyToSubmitTools.length > 0 && + completedAndReadyToSubmitTools.length === toolCalls.length + ) { const responsesToSend: PartListUnion[] = completedAndReadyToSubmitTools.map( (toolCall) => toolCall.response.responseParts, |
