summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useToolScheduler.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src/ui/hooks/useToolScheduler.ts')
-rw-r--r--packages/cli/src/ui/hooks/useToolScheduler.ts69
1 files changed, 63 insertions, 6 deletions
diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts
index 7d8cfbe4..e6e80785 100644
--- a/packages/cli/src/ui/hooks/useToolScheduler.ts
+++ b/packages/cli/src/ui/hooks/useToolScheduler.ts
@@ -13,7 +13,7 @@ import {
ToolCallConfirmationDetails,
ToolResult,
} from '@gemini-code/server';
-import { Part } from '@google/genai';
+import { Part, PartUnion, PartListUnion } from '@google/genai';
import { useCallback, useEffect, useState } from 'react';
import {
HistoryItemToolGroup,
@@ -88,6 +88,60 @@ export type CompletedToolCall =
| CancelledToolCall
| ErroredToolCall;
+/**
+ * Formats a PartListUnion response from a tool into JSON suitable for a Gemini
+ * FunctionResponse and additional Parts to include after that response.
+ *
+ * This is required because FunctionReponse appears to only support JSON
+ * and not arbitrary parts. Including parts like inlineData or fileData
+ * directly in a FunctionResponse confuses the model resulting in a failure
+ * to interpret the multimodal content and context window exceeded errors.
+ */
+
+export function formatLlmContentForFunctionResponse(
+ llmContent: PartListUnion,
+): {
+ functionResponseJson: Record<string, string>;
+ additionalParts: PartUnion[];
+} {
+ const additionalParts: PartUnion[] = [];
+ let functionResponseJson: Record<string, string>;
+
+ if (Array.isArray(llmContent) && llmContent.length === 1) {
+ // Ensure that length 1 arrays are treated as a single Part.
+ llmContent = llmContent[0];
+ }
+
+ if (typeof llmContent === 'string') {
+ functionResponseJson = { output: llmContent };
+ } else if (Array.isArray(llmContent)) {
+ functionResponseJson = { status: 'Tool execution succeeded.' };
+ additionalParts.push(...llmContent);
+ } else {
+ if (
+ llmContent.inlineData !== undefined ||
+ llmContent.fileData !== undefined
+ ) {
+ // For Parts like inlineData or fileData, use the returnDisplay as the textual output for the functionResponse.
+ // The actual Part will be added to additionalParts.
+ functionResponseJson = {
+ status: `Binary content of type ${llmContent.inlineData?.mimeType || llmContent.fileData?.mimeType || 'unknown'} was processed.`,
+ };
+ additionalParts.push(llmContent);
+ } else if (llmContent.text !== undefined) {
+ functionResponseJson = { output: llmContent.text };
+ } else {
+ functionResponseJson = { status: 'Tool execution succeeded.' };
+ additionalParts.push(llmContent);
+ }
+ }
+
+ return {
+ functionResponseJson,
+ additionalParts,
+ };
+}
+
export function useToolScheduler(
onComplete: (tools: CompletedToolCall[]) => void,
config: Config,
@@ -201,7 +255,7 @@ export function useToolScheduler(
status: 'cancelled',
response: {
callId: c.request.callId,
- responsePart: {
+ responseParts: {
functionResponse: {
id: c.request.callId,
name: c.request.name,
@@ -276,21 +330,24 @@ export function useToolScheduler(
.execute(t.request.args, signal, onOutputChunk)
.then((result: ToolResult) => {
if (signal.aborted) {
+ // TODO(jacobr): avoid stringifying the LLM content.
setToolCalls(
setStatus(callId, 'cancelled', String(result.llmContent)),
);
return;
}
+ const { functionResponseJson, additionalParts } =
+ formatLlmContentForFunctionResponse(result.llmContent);
const functionResponse: Part = {
functionResponse: {
name: t.request.name,
id: callId,
- response: { output: result.llmContent },
+ response: functionResponseJson,
},
};
const response: ToolCallResponseInfo = {
callId,
- responsePart: functionResponse,
+ responseParts: [functionResponse, ...additionalParts],
resultDisplay: result.returnDisplay,
error: undefined,
};
@@ -401,7 +458,7 @@ function setStatus(
status: 'cancelled',
response: {
callId: t.request.callId,
- responsePart: {
+ responseParts: {
functionResponse: {
id: t.request.callId,
name: t.request.name,
@@ -446,7 +503,7 @@ const toolErrorResponse = (
): ToolCallResponseInfo => ({
callId: request.callId,
error,
- responsePart: {
+ responseParts: {
functionResponse: {
id: request.callId,
name: request.name,