diff options
| author | Gal Zahavi <[email protected]> | 2025-08-19 16:03:51 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-08-19 23:03:51 +0000 |
| commit | f1575f6d8de2f4efa0805a2d11a4a421a1a8228f (patch) | |
| tree | 8977235b9a42983de3e76189f25ff055e9d28a83 /packages/core/src/services/shellExecutionService.ts | |
| parent | 0cc2a1e7ef904294fff982a4d75bf098b5b262f7 (diff) | |
feat(core): refactor shell execution to use node-pty (#6491)
Co-authored-by: Jacob Richman <[email protected]>
Diffstat (limited to 'packages/core/src/services/shellExecutionService.ts')
| -rw-r--r-- | packages/core/src/services/shellExecutionService.ts | 495 |
1 files changed, 351 insertions, 144 deletions
diff --git a/packages/core/src/services/shellExecutionService.ts b/packages/core/src/services/shellExecutionService.ts index 3749fcf6..59e998bd 100644 --- a/packages/core/src/services/shellExecutionService.ts +++ b/packages/core/src/services/shellExecutionService.ts @@ -4,35 +4,47 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { spawn } from 'child_process'; +import { getPty, PtyImplementation } from '../utils/getPty.js'; +import { spawn as cpSpawn } from 'child_process'; import { TextDecoder } from 'util'; import os from 'os'; -import stripAnsi from 'strip-ansi'; import { getCachedEncodingForBuffer } from '../utils/systemEncoding.js'; import { isBinary } from '../utils/textUtils.js'; +import pkg from '@xterm/headless'; +import stripAnsi from 'strip-ansi'; +const { Terminal } = pkg; const SIGKILL_TIMEOUT_MS = 200; +// @ts-expect-error getFullText is not a public API. +const getFullText = (terminal: Terminal) => { + const buffer = terminal.buffer.active; + const lines: string[] = []; + for (let i = 0; i < buffer.length; i++) { + const line = buffer.getLine(i); + lines.push(line ? line.translateToString(true) : ''); + } + return lines.join('\n').trim(); +}; + /** A structured result from a shell command execution. */ export interface ShellExecutionResult { /** The raw, unprocessed output buffer. */ rawOutput: Buffer; - /** The combined, decoded stdout and stderr as a string. */ + /** The combined, decoded output as a string. */ output: string; - /** The decoded stdout as a string. */ - stdout: string; - /** The decoded stderr as a string. */ - stderr: string; /** The process exit code, or null if terminated by a signal. */ exitCode: number | null; /** The signal that terminated the process, if any. */ - signal: NodeJS.Signals | null; + signal: number | null; /** An error object if the process failed to spawn. */ error: Error | null; /** A boolean indicating if the command was aborted by the user. */ aborted: boolean; /** The process ID of the spawned shell. */ pid: number | undefined; + /** The method used to execute the shell command. */ + executionMethod: 'lydell-node-pty' | 'node-pty' | 'child_process' | 'none'; } /** A handle for an ongoing shell execution. */ @@ -50,8 +62,6 @@ export type ShellOutputEvent = | { /** The event contains a chunk of output data. */ type: 'data'; - /** The stream from which the data originated. */ - stream: 'stdout' | 'stderr'; /** The decoded string chunk. */ chunk: string; } @@ -73,7 +83,7 @@ export type ShellOutputEvent = */ export class ShellExecutionService { /** - * Executes a shell command using `spawn`, capturing all output and lifecycle events. + * Executes a shell command using `node-pty`, capturing all output and lifecycle events. * * @param commandToExecute The exact command string to run. * @param cwd The working directory to execute the command in. @@ -82,172 +92,369 @@ export class ShellExecutionService { * @returns An object containing the process ID (pid) and a promise that * resolves with the complete execution result. */ - static execute( + static async execute( commandToExecute: string, cwd: string, onOutputEvent: (event: ShellOutputEvent) => void, abortSignal: AbortSignal, - ): ShellExecutionHandle { - const isWindows = os.platform() === 'win32'; + shouldUseNodePty: boolean, + terminalColumns?: number, + terminalRows?: number, + ): Promise<ShellExecutionHandle> { + if (shouldUseNodePty) { + const ptyInfo = await getPty(); + if (ptyInfo) { + try { + return this.executeWithPty( + commandToExecute, + cwd, + onOutputEvent, + abortSignal, + terminalColumns, + terminalRows, + ptyInfo, + ); + } catch (_e) { + // Fallback to child_process + } + } + } - const child = spawn(commandToExecute, [], { + return this.childProcessFallback( + commandToExecute, cwd, - stdio: ['ignore', 'pipe', 'pipe'], - // Use bash unless in Windows (since it doesn't support bash). - // For windows, just use the default. - shell: isWindows ? true : 'bash', - // Use process groups on non-Windows for robust killing. - // Windows process termination is handled by `taskkill /t`. - detached: !isWindows, - env: { - ...process.env, - GEMINI_CLI: '1', - }, - }); + onOutputEvent, + abortSignal, + ); + } + + private static childProcessFallback( + commandToExecute: string, + cwd: string, + onOutputEvent: (event: ShellOutputEvent) => void, + abortSignal: AbortSignal, + ): ShellExecutionHandle { + try { + const isWindows = os.platform() === 'win32'; - const result = new Promise<ShellExecutionResult>((resolve) => { - // Use decoders to handle multi-byte characters safely (for streaming output). - let stdoutDecoder: TextDecoder | null = null; - let stderrDecoder: TextDecoder | null = null; + const child = cpSpawn(commandToExecute, [], { + cwd, + stdio: ['ignore', 'pipe', 'pipe'], + shell: isWindows ? true : 'bash', + detached: !isWindows, + env: { + ...process.env, + GEMINI_CLI: '1', + TERM: 'xterm-256color', + PAGER: 'cat', + }, + }); - let stdout = ''; - let stderr = ''; - const outputChunks: Buffer[] = []; - let error: Error | null = null; - let exited = false; + const result = new Promise<ShellExecutionResult>((resolve) => { + let stdoutDecoder: TextDecoder | null = null; + let stderrDecoder: TextDecoder | null = null; - let isStreamingRawContent = true; - const MAX_SNIFF_SIZE = 4096; - let sniffedBytes = 0; + let stdout = ''; + let stderr = ''; + const outputChunks: Buffer[] = []; + let error: Error | null = null; + let exited = false; - const handleOutput = (data: Buffer, stream: 'stdout' | 'stderr') => { - if (!stdoutDecoder || !stderrDecoder) { - const encoding = getCachedEncodingForBuffer(data); - try { - stdoutDecoder = new TextDecoder(encoding); - stderrDecoder = new TextDecoder(encoding); - } catch { - // If the encoding is not supported, fall back to utf-8. - // This can happen on some platforms for certain encodings like 'utf-32le'. - stdoutDecoder = new TextDecoder('utf-8'); - stderrDecoder = new TextDecoder('utf-8'); + let isStreamingRawContent = true; + const MAX_SNIFF_SIZE = 4096; + let sniffedBytes = 0; + + const handleOutput = (data: Buffer, stream: 'stdout' | 'stderr') => { + if (!stdoutDecoder || !stderrDecoder) { + const encoding = getCachedEncodingForBuffer(data); + try { + stdoutDecoder = new TextDecoder(encoding); + stderrDecoder = new TextDecoder(encoding); + } catch { + stdoutDecoder = new TextDecoder('utf-8'); + stderrDecoder = new TextDecoder('utf-8'); + } } - } - outputChunks.push(data); + outputChunks.push(data); + + if (isStreamingRawContent && sniffedBytes < MAX_SNIFF_SIZE) { + const sniffBuffer = Buffer.concat(outputChunks.slice(0, 20)); + sniffedBytes = sniffBuffer.length; - // Binary detection logic. This only runs until we've made a determination. - if (isStreamingRawContent && sniffedBytes < MAX_SNIFF_SIZE) { - const sniffBuffer = Buffer.concat(outputChunks.slice(0, 20)); - sniffedBytes = sniffBuffer.length; + if (isBinary(sniffBuffer)) { + isStreamingRawContent = false; + onOutputEvent({ type: 'binary_detected' }); + } + } + + const decoder = stream === 'stdout' ? stdoutDecoder : stderrDecoder; + const decodedChunk = decoder.decode(data, { stream: true }); + const strippedChunk = stripAnsi(decodedChunk); - if (isBinary(sniffBuffer)) { - // Change state to stop streaming raw content. - isStreamingRawContent = false; - onOutputEvent({ type: 'binary_detected' }); + if (stream === 'stdout') { + stdout += strippedChunk; + } else { + stderr += strippedChunk; } - } - const decodedChunk = - stream === 'stdout' - ? stdoutDecoder.decode(data, { stream: true }) - : stderrDecoder.decode(data, { stream: true }); - const strippedChunk = stripAnsi(decodedChunk); + if (isStreamingRawContent) { + onOutputEvent({ type: 'data', chunk: strippedChunk }); + } else { + const totalBytes = outputChunks.reduce( + (sum, chunk) => sum + chunk.length, + 0, + ); + onOutputEvent({ + type: 'binary_progress', + bytesReceived: totalBytes, + }); + } + }; - if (stream === 'stdout') { - stdout += strippedChunk; - } else { - stderr += strippedChunk; - } + const handleExit = ( + code: number | null, + signal: NodeJS.Signals | null, + ) => { + const { finalBuffer } = cleanup(); + // Ensure we don't add an extra newline if stdout already ends with one. + const separator = stdout.endsWith('\n') ? '' : '\n'; + const combinedOutput = + stdout + (stderr ? (stdout ? separator : '') + stderr : ''); - if (isStreamingRawContent) { - onOutputEvent({ type: 'data', stream, chunk: strippedChunk }); - } else { - const totalBytes = outputChunks.reduce( - (sum, chunk) => sum + chunk.length, - 0, - ); - onOutputEvent({ type: 'binary_progress', bytesReceived: totalBytes }); - } - }; + resolve({ + rawOutput: finalBuffer, + output: combinedOutput.trim(), + exitCode: code, + signal: signal ? os.constants.signals[signal] : null, + error, + aborted: abortSignal.aborted, + pid: child.pid, + executionMethod: 'child_process', + }); + }; - child.stdout.on('data', (data) => handleOutput(data, 'stdout')); - child.stderr.on('data', (data) => handleOutput(data, 'stderr')); - child.on('error', (err) => { - const { stdout, stderr, finalBuffer } = cleanup(); - error = err; - resolve({ - error, - stdout, - stderr, - rawOutput: finalBuffer, - output: stdout + (stderr ? `\n${stderr}` : ''), - exitCode: 1, - signal: null, - aborted: false, - pid: child.pid, + child.stdout.on('data', (data) => handleOutput(data, 'stdout')); + child.stderr.on('data', (data) => handleOutput(data, 'stderr')); + child.on('error', (err) => { + error = err; + handleExit(1, null); }); - }); - const abortHandler = async () => { - if (child.pid && !exited) { - if (isWindows) { - spawn('taskkill', ['/pid', child.pid.toString(), '/f', '/t']); - } else { - try { - // Kill the entire process group (negative PID). - // SIGTERM first, then SIGKILL if it doesn't die. - process.kill(-child.pid, 'SIGTERM'); - await new Promise((res) => setTimeout(res, SIGKILL_TIMEOUT_MS)); - if (!exited) { - process.kill(-child.pid, 'SIGKILL'); + const abortHandler = async () => { + if (child.pid && !exited) { + if (isWindows) { + cpSpawn('taskkill', ['/pid', child.pid.toString(), '/f', '/t']); + } else { + try { + process.kill(-child.pid, 'SIGTERM'); + await new Promise((res) => setTimeout(res, SIGKILL_TIMEOUT_MS)); + if (!exited) { + process.kill(-child.pid, 'SIGKILL'); + } + } catch (_e) { + if (!exited) child.kill('SIGKILL'); } - } catch (_e) { - // Fall back to killing just the main process if group kill fails. - if (!exited) child.kill('SIGKILL'); } } + }; + + abortSignal.addEventListener('abort', abortHandler, { once: true }); + + child.on('exit', (code, signal) => { + handleExit(code, signal); + }); + + function cleanup() { + exited = true; + abortSignal.removeEventListener('abort', abortHandler); + if (stdoutDecoder) { + const remaining = stdoutDecoder.decode(); + if (remaining) { + stdout += stripAnsi(remaining); + } + } + if (stderrDecoder) { + const remaining = stderrDecoder.decode(); + if (remaining) { + stderr += stripAnsi(remaining); + } + } + + const finalBuffer = Buffer.concat(outputChunks); + + return { stdout, stderr, finalBuffer }; } + }); + + return { pid: child.pid, result }; + } catch (e) { + const error = e as Error; + return { + pid: undefined, + result: Promise.resolve({ + error, + rawOutput: Buffer.from(''), + output: '', + exitCode: 1, + signal: null, + aborted: false, + pid: undefined, + executionMethod: 'none', + }), }; + } + } - abortSignal.addEventListener('abort', abortHandler, { once: true }); + private static executeWithPty( + commandToExecute: string, + cwd: string, + onOutputEvent: (event: ShellOutputEvent) => void, + abortSignal: AbortSignal, + terminalColumns: number | undefined, + terminalRows: number | undefined, + ptyInfo: PtyImplementation | undefined, + ): ShellExecutionHandle { + try { + const cols = terminalColumns ?? 80; + const rows = terminalRows ?? 30; + const isWindows = os.platform() === 'win32'; + const shell = isWindows ? 'cmd.exe' : 'bash'; + const args = isWindows + ? ['/c', commandToExecute] + : ['-c', commandToExecute]; - child.on('exit', (code: number, signal: NodeJS.Signals) => { - const { stdout, stderr, finalBuffer } = cleanup(); + const ptyProcess = ptyInfo?.module.spawn(shell, args, { + cwd, + name: 'xterm-color', + cols, + rows, + env: { + ...process.env, + GEMINI_CLI: '1', + TERM: 'xterm-256color', + PAGER: 'cat', + }, + handleFlowControl: true, + }); - resolve({ - rawOutput: finalBuffer, - output: stdout + (stderr ? `\n${stderr}` : ''), - stdout, - stderr, - exitCode: code, - signal, - error, - aborted: abortSignal.aborted, - pid: child.pid, + const result = new Promise<ShellExecutionResult>((resolve) => { + const headlessTerminal = new Terminal({ + allowProposedApi: true, + cols, + rows, }); - }); + let processingChain = Promise.resolve(); + let decoder: TextDecoder | null = null; + let output = ''; + const outputChunks: Buffer[] = []; + const error: Error | null = null; + let exited = false; - /** - * Cleans up a process (and it's accompanying state) that is exiting or - * erroring and returns output formatted output buffers and strings - */ - function cleanup() { - exited = true; - abortSignal.removeEventListener('abort', abortHandler); - if (stdoutDecoder) { - stdout += stripAnsi(stdoutDecoder.decode()); - } - if (stderrDecoder) { - stderr += stripAnsi(stderrDecoder.decode()); - } + let isStreamingRawContent = true; + const MAX_SNIFF_SIZE = 4096; + let sniffedBytes = 0; - const finalBuffer = Buffer.concat(outputChunks); + const handleOutput = (data: Buffer) => { + processingChain = processingChain.then( + () => + new Promise<void>((resolve) => { + if (!decoder) { + const encoding = getCachedEncodingForBuffer(data); + try { + decoder = new TextDecoder(encoding); + } catch { + decoder = new TextDecoder('utf-8'); + } + } - return { stdout, stderr, finalBuffer }; - } - }); + outputChunks.push(data); + + if (isStreamingRawContent && sniffedBytes < MAX_SNIFF_SIZE) { + const sniffBuffer = Buffer.concat(outputChunks.slice(0, 20)); + sniffedBytes = sniffBuffer.length; + + if (isBinary(sniffBuffer)) { + isStreamingRawContent = false; + onOutputEvent({ type: 'binary_detected' }); + } + } + + if (isStreamingRawContent) { + const decodedChunk = decoder.decode(data, { stream: true }); + headlessTerminal.write(decodedChunk, () => { + const newStrippedOutput = getFullText(headlessTerminal); + output = newStrippedOutput; + onOutputEvent({ type: 'data', chunk: newStrippedOutput }); + resolve(); + }); + } else { + const totalBytes = outputChunks.reduce( + (sum, chunk) => sum + chunk.length, + 0, + ); + onOutputEvent({ + type: 'binary_progress', + bytesReceived: totalBytes, + }); + resolve(); + } + }), + ); + }; + + ptyProcess.onData((data: string) => { + const bufferData = Buffer.from(data, 'utf-8'); + handleOutput(bufferData); + }); + + ptyProcess.onExit( + ({ exitCode, signal }: { exitCode: number; signal?: number }) => { + exited = true; + abortSignal.removeEventListener('abort', abortHandler); + + processingChain.then(() => { + const finalBuffer = Buffer.concat(outputChunks); - return { pid: child.pid, result }; + resolve({ + rawOutput: finalBuffer, + output, + exitCode, + signal: signal ?? null, + error, + aborted: abortSignal.aborted, + pid: ptyProcess.pid, + executionMethod: ptyInfo?.name ?? 'node-pty', + }); + }); + }, + ); + + const abortHandler = async () => { + if (ptyProcess.pid && !exited) { + ptyProcess.kill('SIGHUP'); + } + }; + + abortSignal.addEventListener('abort', abortHandler, { once: true }); + }); + + return { pid: ptyProcess.pid, result }; + } catch (e) { + const error = e as Error; + return { + pid: undefined, + result: Promise.resolve({ + error, + rawOutput: Buffer.from(''), + output: '', + exitCode: 1, + signal: null, + aborted: false, + pid: undefined, + executionMethod: 'none', + }), + }; + } } } |
