diff options
Diffstat (limited to 'packages/core/src/code_assist/oauth2.ts')
| -rw-r--r-- | packages/core/src/code_assist/oauth2.ts | 61 |
1 files changed, 55 insertions, 6 deletions
diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index af87caea..7d65d260 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -10,6 +10,8 @@ import url from 'url'; import crypto from 'crypto'; import * as net from 'net'; import open from 'open'; +import path from 'node:path'; +import { promises as fs } from 'node:fs'; // OAuth Client ID used to initiate OAuth2Client class. const OAUTH_CLIENT_ID = @@ -36,7 +38,50 @@ const SIGN_IN_SUCCESS_URL = const SIGN_IN_FAILURE_URL = 'https://developers.google.com/gemini-code-assist/auth_failure_gemini'; -export async function loginWithOauth(): Promise<OAuth2Client> { +const GEMINI_DIR = '.gemini'; +const CREDENTIAL_FILENAME = 'oauth_creds.json'; + +export async function getCachedCredentialClient(): Promise<OAuth2Client> { + try { + const creds = await fs.readFile( + path.join(process.cwd(), GEMINI_DIR, CREDENTIAL_FILENAME), + 'utf-8', + ); + + const oAuth2Client = new OAuth2Client({ + clientId: OAUTH_CLIENT_ID, + clientSecret: OAUTH_CLIENT_SECRET, + }); + oAuth2Client.setCredentials(JSON.parse(creds)); + // This will either return the existing token or refresh it. + await oAuth2Client.getAccessToken(); + // If we are here, the token is valid. + return oAuth2Client; + } catch (_) { + // Could not load credentials. + throw new Error('Could not load credentials'); + } +} + +export async function clearCachedCredentials(): Promise<void> { + await fs.rm(path.join(process.cwd(), GEMINI_DIR, CREDENTIAL_FILENAME)); +} + +export async function getOauthClient(): Promise<OAuth2Client> { + try { + return await getCachedCredentialClient(); + } catch (_) { + const loggedInClient = await webLoginClient(); + await fs.mkdir(path.join(process.cwd(), GEMINI_DIR), { recursive: true }); + await fs.writeFile( + path.join(process.cwd(), GEMINI_DIR, CREDENTIAL_FILENAME), + JSON.stringify(loggedInClient.credentials, null, 2), + ); + return loggedInClient; + } +} + +export async function webLoginClient(): Promise<OAuth2Client> { const port = await getAvailablePort(); const oAuth2Client = new OAuth2Client({ clientId: OAUTH_CLIENT_ID, @@ -51,33 +96,37 @@ export async function loginWithOauth(): Promise<OAuth2Client> { scope: OAUTH_SCOPE, state, }); + console.log( + `\n\nCode Assist login required.\n` + + `Attempting to open authentication page in your browser.\n` + + `Otherwise navigate to:\n\n${authURL}\n\n`, + ); open(authURL); + console.log('Waiting for authentication...'); const server = http.createServer(async (req, res) => { try { if (req.url!.indexOf('/oauth2callback') === -1) { - console.log('Unexpected request:', req.url); res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL }); res.end(); reject(new Error('Unexpected request: ' + req.url)); } // acquire the code from the querystring, and close the web server. const qs = new url.URL(req.url!, 'http://localhost:3000').searchParams; - console.log('Processing request:', qs); if (qs.get('error')) { res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL }); res.end(); + reject(new Error(`Error during authentication: ${qs.get('error')}`)); } else if (qs.get('state') !== state) { res.end('State mismatch. Possible CSRF attack'); + reject(new Error('State mismatch. Possible CSRF attack')); } else if (qs.get('code')) { const code: string = qs.get('code')!; - console.log(); const { tokens } = await oAuth2Client.getToken(code); - console.log('Logged in! Tokens:\n\n', tokens); - oAuth2Client.setCredentials(tokens); + res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL }); res.end(); resolve(oAuth2Client); |
