diff options
Diffstat (limited to 'packages/server/src/core')
| -rw-r--r-- | packages/server/src/core/client.ts | 13 | ||||
| -rw-r--r-- | packages/server/src/core/geminiChat.ts | 4 | ||||
| -rw-r--r-- | packages/server/src/core/turn.test.ts | 27 | ||||
| -rw-r--r-- | packages/server/src/core/turn.ts | 12 |
4 files changed, 44 insertions, 12 deletions
diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 341ce021..69b815ab 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -157,7 +157,7 @@ export class GeminiClient { async *sendMessageStream( chat: GeminiChat, request: PartListUnion, - signal?: AbortSignal, + signal: AbortSignal, turns: number = this.MAX_TURNS, ): AsyncGenerator<ServerGeminiStreamEvent> { if (!turns) { @@ -169,8 +169,8 @@ export class GeminiClient { for await (const event of resultStream) { yield event; } - if (!turn.pendingToolCalls.length) { - const nextSpeakerCheck = await checkNextSpeaker(chat, this); + if (!turn.pendingToolCalls.length && signal && !signal.aborted) { + const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); if (nextSpeakerCheck?.next_speaker === 'model') { const nextRequest = [{ text: 'Please continue.' }]; yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1); @@ -181,6 +181,7 @@ export class GeminiClient { async generateJson( contents: Content[], schema: SchemaUnion, + abortSignal: AbortSignal, model: string = 'gemini-2.0-flash', config: GenerateContentConfig = {}, ): Promise<Record<string, unknown>> { @@ -188,6 +189,7 @@ export class GeminiClient { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); const requestConfig = { + abortSignal, ...this.generateContentConfig, ...config, }; @@ -232,6 +234,11 @@ export class GeminiClient { ); } } catch (error) { + if (abortSignal.aborted) { + // Regular cancellation error, fail normally + throw error; + } + // Avoid double reporting for the empty response case handled above if ( error instanceof Error && diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts index c971e2cc..5ba8ce2d 100644 --- a/packages/server/src/core/geminiChat.ts +++ b/packages/server/src/core/geminiChat.ts @@ -155,7 +155,7 @@ export class GeminiChat { const responsePromise = this.modelsModule.generateContent({ model: this.model, contents: this.getHistory(true).concat(userContent), - config: params.config ?? this.config, + config: { ...this.config, ...params.config }, }); this.sendPromise = (async () => { const response = await responsePromise; @@ -219,7 +219,7 @@ export class GeminiChat { const streamResponse = this.modelsModule.generateContentStream({ model: this.model, contents: this.getHistory(true).concat(userContent), - config: params.config ?? this.config, + config: { ...this.config, ...params.config }, }); // Resolve the internal tracking of send completion promise - `sendPromise` // for both success and failure response. The actual failure is still diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts index 44bb983f..8fb3a4c1 100644 --- a/packages/server/src/core/turn.test.ts +++ b/packages/server/src/core/turn.test.ts @@ -85,11 +85,17 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Hi' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } - expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts }); + expect(mockSendMessageStream).toHaveBeenCalledWith({ + message: reqParts, + config: { abortSignal: expect.any(AbortSignal) }, + }); expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: ' world' }, @@ -110,7 +116,10 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Use tools' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -179,7 +188,10 @@ describe('Turn', () => { mockGetHistory.mockReturnValue(historyContent); const events = []; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -210,7 +222,10 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -261,7 +276,7 @@ describe('Turn', () => { })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const reqParts: Part[] = [{ text: 'Hi' }]; - for await (const _ of turn.run(reqParts)) { + for await (const _ of turn.run(reqParts, new AbortController().signal)) { // consume stream } expect(turn.getDebugResponses()).toEqual([resp1, resp2]); diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index d5c7eb58..97e93f59 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -32,6 +32,7 @@ export interface ServerTool { ): Promise<ToolResult>; shouldConfirmExecute( params: Record<string, unknown>, + abortSignal: AbortSignal, ): Promise<ToolCallConfirmationDetails | false>; } @@ -120,11 +121,14 @@ export class Turn { // The run method yields simpler events suitable for server logic async *run( req: PartListUnion, - signal?: AbortSignal, + signal: AbortSignal, ): AsyncGenerator<ServerGeminiStreamEvent> { try { const responseStream = await this.chat.sendMessageStream({ message: req, + config: { + abortSignal: signal, + }, }); for await (const resp of responseStream) { @@ -150,6 +154,12 @@ export class Turn { } } } catch (error) { + if (signal.aborted) { + yield { type: GeminiEventType.UserCancelled }; + // Regular cancellation error, fail gracefully. + return; + } + const contextForReport = [...this.chat.getHistory(/*curated*/ true), req]; await reportError( error, |
