diff options
Diffstat (limited to 'packages/core/src/code_assist/oauth2.ts')
| -rw-r--r-- | packages/core/src/code_assist/oauth2.ts | 104 |
1 files changed, 61 insertions, 43 deletions
diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 9e15f65b..6527f957 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -42,39 +42,54 @@ const SIGN_IN_FAILURE_URL = const GEMINI_DIR = '.gemini'; const CREDENTIAL_FILENAME = 'oauth_creds.json'; -export async function getOauthClient(): Promise<OAuth2Client> { - try { - return await getCachedCredentialClient(); - } catch (_) { - const loggedInClient = await webLoginClient(); - await setCachedCredentials(loggedInClient.credentials); - return loggedInClient; - } +/** + * An Authentication URL for updating the credentials of a Oauth2Client + * as well as a promise that will resolve when the credentials have + * been refreshed (or which throws error when refreshing credentials failed). + */ +export interface OauthWebLogin { + authUrl: string; + loginCompletePromise: Promise<void>; } -async function webLoginClient(): Promise<OAuth2Client> { - const port = await getAvailablePort(); - const oAuth2Client = new OAuth2Client({ +export async function getOauthClient(): Promise<OAuth2Client> { + const client = new OAuth2Client({ clientId: OAUTH_CLIENT_ID, clientSecret: OAUTH_CLIENT_SECRET, - redirectUri: `http://localhost:${port}/oauth2callback`, }); - return new Promise((resolve, reject) => { - const state = crypto.randomBytes(32).toString('hex'); - const authURL: string = oAuth2Client.generateAuthUrl({ - access_type: 'offline', - scope: OAUTH_SCOPE, - state, - }); - console.log( - `\n\nCode Assist login required.\n` + - `Attempting to open authentication page in your browser.\n` + - `Otherwise navigate to:\n\n${authURL}\n\n`, - ); - open(authURL); - console.log('Waiting for authentication...'); + if (await loadCachedCredentials(client)) { + // Found valid cached credentials. + return client; + } + const webLogin = await authWithWeb(client); + + console.log( + `\n\nCode Assist login required.\n` + + `Attempting to open authentication page in your browser.\n` + + `Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`, + ); + await open(webLogin.authUrl); + console.log('Waiting for authentication...'); + + await webLogin.loginCompletePromise; + + return client; +} + +async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> { + const port = await getAvailablePort(); + const redirectUri = `http://localhost:${port}/oauth2callback`; + const state = crypto.randomBytes(32).toString('hex'); + const authUrl: string = client.generateAuthUrl({ + redirect_uri: redirectUri, + access_type: 'offline', + scope: OAUTH_SCOPE, + state, + }); + + const loginCompletePromise = new Promise<void>((resolve, reject) => { const server = http.createServer(async (req, res) => { try { if (req.url!.indexOf('/oauth2callback') === -1) { @@ -94,13 +109,16 @@ async function webLoginClient(): Promise<OAuth2Client> { reject(new Error('State mismatch. Possible CSRF attack')); } else if (qs.get('code')) { - const code: string = qs.get('code')!; - const { tokens } = await oAuth2Client.getToken(code); - oAuth2Client.setCredentials(tokens); + const { tokens } = await client.getToken({ + code: qs.get('code')!, + redirect_uri: redirectUri, + }); + client.setCredentials(tokens); + await cacheCredentials(client.credentials); res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL }); res.end(); - resolve(oAuth2Client); + resolve(); } else { reject(new Error('No code found in request')); } @@ -112,9 +130,14 @@ async function webLoginClient(): Promise<OAuth2Client> { }); server.listen(port); }); + + return { + authUrl, + loginCompletePromise, + }; } -function getAvailablePort(): Promise<number> { +export function getAvailablePort(): Promise<number> { return new Promise((resolve, reject) => { let port = 0; try { @@ -135,25 +158,20 @@ function getAvailablePort(): Promise<number> { }); } -async function getCachedCredentialClient(): Promise<OAuth2Client> { +async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> { try { const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8'); - const oAuth2Client = new OAuth2Client({ - clientId: OAUTH_CLIENT_ID, - clientSecret: OAUTH_CLIENT_SECRET, - }); - oAuth2Client.setCredentials(JSON.parse(creds)); + client.setCredentials(JSON.parse(creds)); // This will either return the existing token or refresh it. - await oAuth2Client.getAccessToken(); - // If we are here, the token is valid. - return oAuth2Client; + await client.getAccessToken(); + + return true; } catch (_) { - // Could not load credentials. - throw new Error('Could not load credentials'); + return false; } } -async function setCachedCredentials(credentials: Credentials) { +async function cacheCredentials(credentials: Credentials) { const filePath = getCachedCredentialPath(); await fs.mkdir(path.dirname(filePath), { recursive: true }); |
