diff options
Diffstat (limited to 'packages/server/src/core/turn.test.ts')
| -rw-r--r-- | packages/server/src/core/turn.test.ts | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts index 44bb983f..8fb3a4c1 100644 --- a/packages/server/src/core/turn.test.ts +++ b/packages/server/src/core/turn.test.ts @@ -85,11 +85,17 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Hi' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } - expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts }); + expect(mockSendMessageStream).toHaveBeenCalledWith({ + message: reqParts, + config: { abortSignal: expect.any(AbortSignal) }, + }); expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: ' world' }, @@ -110,7 +116,10 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Use tools' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -179,7 +188,10 @@ describe('Turn', () => { mockGetHistory.mockReturnValue(historyContent); const events = []; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -210,7 +222,10 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -261,7 +276,7 @@ describe('Turn', () => { })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const reqParts: Part[] = [{ text: 'Hi' }]; - for await (const _ of turn.run(reqParts)) { + for await (const _ of turn.run(reqParts, new AbortController().signal)) { // consume stream } expect(turn.getDebugResponses()).toEqual([resp1, resp2]); |
