summaryrefslogtreecommitdiff
path: root/packages/server/src/core/turn.test.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/server/src/core/turn.test.ts')
-rw-r--r--packages/server/src/core/turn.test.ts27
1 files changed, 21 insertions, 6 deletions
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts
index 44bb983f..8fb3a4c1 100644
--- a/packages/server/src/core/turn.test.ts
+++ b/packages/server/src/core/turn.test.ts
@@ -85,11 +85,17 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Hi' }];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
- expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts });
+ expect(mockSendMessageStream).toHaveBeenCalledWith({
+ message: reqParts,
+ config: { abortSignal: expect.any(AbortSignal) },
+ });
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' },
@@ -110,7 +116,10 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Use tools' }];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
@@ -179,7 +188,10 @@ describe('Turn', () => {
mockGetHistory.mockReturnValue(historyContent);
const events = [];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
@@ -210,7 +222,10 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
@@ -261,7 +276,7 @@ describe('Turn', () => {
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const reqParts: Part[] = [{ text: 'Hi' }];
- for await (const _ of turn.run(reqParts)) {
+ for await (const _ of turn.run(reqParts, new AbortController().signal)) {
// consume stream
}
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);