summaryrefslogtreecommitdiff
path: root/packages/core/src/code_assist/oauth2.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/code_assist/oauth2.ts')
-rw-r--r--packages/core/src/code_assist/oauth2.ts104
1 files changed, 61 insertions, 43 deletions
diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts
index 9e15f65b..6527f957 100644
--- a/packages/core/src/code_assist/oauth2.ts
+++ b/packages/core/src/code_assist/oauth2.ts
@@ -42,39 +42,54 @@ const SIGN_IN_FAILURE_URL =
const GEMINI_DIR = '.gemini';
const CREDENTIAL_FILENAME = 'oauth_creds.json';
-export async function getOauthClient(): Promise<OAuth2Client> {
- try {
- return await getCachedCredentialClient();
- } catch (_) {
- const loggedInClient = await webLoginClient();
- await setCachedCredentials(loggedInClient.credentials);
- return loggedInClient;
- }
+/**
+ * An Authentication URL for updating the credentials of a Oauth2Client
+ * as well as a promise that will resolve when the credentials have
+ * been refreshed (or which throws error when refreshing credentials failed).
+ */
+export interface OauthWebLogin {
+ authUrl: string;
+ loginCompletePromise: Promise<void>;
}
-async function webLoginClient(): Promise<OAuth2Client> {
- const port = await getAvailablePort();
- const oAuth2Client = new OAuth2Client({
+export async function getOauthClient(): Promise<OAuth2Client> {
+ const client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
- redirectUri: `http://localhost:${port}/oauth2callback`,
});
- return new Promise((resolve, reject) => {
- const state = crypto.randomBytes(32).toString('hex');
- const authURL: string = oAuth2Client.generateAuthUrl({
- access_type: 'offline',
- scope: OAUTH_SCOPE,
- state,
- });
- console.log(
- `\n\nCode Assist login required.\n` +
- `Attempting to open authentication page in your browser.\n` +
- `Otherwise navigate to:\n\n${authURL}\n\n`,
- );
- open(authURL);
- console.log('Waiting for authentication...');
+ if (await loadCachedCredentials(client)) {
+ // Found valid cached credentials.
+ return client;
+ }
+ const webLogin = await authWithWeb(client);
+
+ console.log(
+ `\n\nCode Assist login required.\n` +
+ `Attempting to open authentication page in your browser.\n` +
+ `Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
+ );
+ await open(webLogin.authUrl);
+ console.log('Waiting for authentication...');
+
+ await webLogin.loginCompletePromise;
+
+ return client;
+}
+
+async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
+ const port = await getAvailablePort();
+ const redirectUri = `http://localhost:${port}/oauth2callback`;
+ const state = crypto.randomBytes(32).toString('hex');
+ const authUrl: string = client.generateAuthUrl({
+ redirect_uri: redirectUri,
+ access_type: 'offline',
+ scope: OAUTH_SCOPE,
+ state,
+ });
+
+ const loginCompletePromise = new Promise<void>((resolve, reject) => {
const server = http.createServer(async (req, res) => {
try {
if (req.url!.indexOf('/oauth2callback') === -1) {
@@ -94,13 +109,16 @@ async function webLoginClient(): Promise<OAuth2Client> {
reject(new Error('State mismatch. Possible CSRF attack'));
} else if (qs.get('code')) {
- const code: string = qs.get('code')!;
- const { tokens } = await oAuth2Client.getToken(code);
- oAuth2Client.setCredentials(tokens);
+ const { tokens } = await client.getToken({
+ code: qs.get('code')!,
+ redirect_uri: redirectUri,
+ });
+ client.setCredentials(tokens);
+ await cacheCredentials(client.credentials);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end();
- resolve(oAuth2Client);
+ resolve();
} else {
reject(new Error('No code found in request'));
}
@@ -112,9 +130,14 @@ async function webLoginClient(): Promise<OAuth2Client> {
});
server.listen(port);
});
+
+ return {
+ authUrl,
+ loginCompletePromise,
+ };
}
-function getAvailablePort(): Promise<number> {
+export function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => {
let port = 0;
try {
@@ -135,25 +158,20 @@ function getAvailablePort(): Promise<number> {
});
}
-async function getCachedCredentialClient(): Promise<OAuth2Client> {
+async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
try {
const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8');
- const oAuth2Client = new OAuth2Client({
- clientId: OAUTH_CLIENT_ID,
- clientSecret: OAUTH_CLIENT_SECRET,
- });
- oAuth2Client.setCredentials(JSON.parse(creds));
+ client.setCredentials(JSON.parse(creds));
// This will either return the existing token or refresh it.
- await oAuth2Client.getAccessToken();
- // If we are here, the token is valid.
- return oAuth2Client;
+ await client.getAccessToken();
+
+ return true;
} catch (_) {
- // Could not load credentials.
- throw new Error('Could not load credentials');
+ return false;
}
}
-async function setCachedCredentials(credentials: Credentials) {
+async function cacheCredentials(credentials: Credentials) {
const filePath = getCachedCredentialPath();
await fs.mkdir(path.dirname(filePath), { recursive: true });