diff options
| author | Tommaso Sciortino <[email protected]> | 2025-06-18 16:34:00 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-18 16:34:00 -0700 |
| commit | 8bc3b415c973794654d64d434949a93fb3239acb (patch) | |
| tree | 464339fab8b05f754ae158fed50855b2a5f61231 /packages/core/src/code_assist | |
| parent | b96fbd913e8449c99ed7b95920652acdce5dd779 (diff) | |
Refactor in preparation for Reauth (#1196)
Diffstat (limited to 'packages/core/src/code_assist')
| -rw-r--r-- | packages/core/src/code_assist/oauth2.test.ts | 9 | ||||
| -rw-r--r-- | packages/core/src/code_assist/oauth2.ts | 104 |
2 files changed, 68 insertions, 45 deletions
diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts index 47bd45b3..0f5b791b 100644 --- a/packages/core/src/code_assist/oauth2.test.ts +++ b/packages/core/src/code_assist/oauth2.test.ts @@ -73,8 +73,10 @@ describe('oauth2', () => { (resolve) => (serverListeningCallback = resolve), ); + let capturedPort = 0; const mockHttpServer = { listen: vi.fn((port: number, callback?: () => void) => { + capturedPort = port; if (callback) { callback(); } @@ -86,7 +88,7 @@ describe('oauth2', () => { } }), on: vi.fn(), - address: () => ({ port: 1234 }), + address: () => ({ port: capturedPort }), }; vi.mocked(http.createServer).mockImplementation((cb) => { requestCallback = cb as http.RequestListener< @@ -115,7 +117,10 @@ describe('oauth2', () => { expect(client).toBe(mockOAuth2Client); expect(open).toHaveBeenCalledWith(mockAuthUrl); - expect(mockGetToken).toHaveBeenCalledWith(mockCode); + expect(mockGetToken).toHaveBeenCalledWith({ + code: mockCode, + redirect_uri: `http://localhost:${capturedPort}/oauth2callback`, + }); expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens); const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json'); diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 9e15f65b..6527f957 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -42,39 +42,54 @@ const SIGN_IN_FAILURE_URL = const GEMINI_DIR = '.gemini'; const CREDENTIAL_FILENAME = 'oauth_creds.json'; -export async function getOauthClient(): Promise<OAuth2Client> { - try { - return await getCachedCredentialClient(); - } catch (_) { - const loggedInClient = await webLoginClient(); - await setCachedCredentials(loggedInClient.credentials); - return loggedInClient; - } +/** + * An Authentication URL for updating the credentials of a Oauth2Client + * as well as a promise that will resolve when the credentials have + * been refreshed (or which throws error when refreshing credentials failed). + */ +export interface OauthWebLogin { + authUrl: string; + loginCompletePromise: Promise<void>; } -async function webLoginClient(): Promise<OAuth2Client> { - const port = await getAvailablePort(); - const oAuth2Client = new OAuth2Client({ +export async function getOauthClient(): Promise<OAuth2Client> { + const client = new OAuth2Client({ clientId: OAUTH_CLIENT_ID, clientSecret: OAUTH_CLIENT_SECRET, - redirectUri: `http://localhost:${port}/oauth2callback`, }); - return new Promise((resolve, reject) => { - const state = crypto.randomBytes(32).toString('hex'); - const authURL: string = oAuth2Client.generateAuthUrl({ - access_type: 'offline', - scope: OAUTH_SCOPE, - state, - }); - console.log( - `\n\nCode Assist login required.\n` + - `Attempting to open authentication page in your browser.\n` + - `Otherwise navigate to:\n\n${authURL}\n\n`, - ); - open(authURL); - console.log('Waiting for authentication...'); + if (await loadCachedCredentials(client)) { + // Found valid cached credentials. + return client; + } + const webLogin = await authWithWeb(client); + + console.log( + `\n\nCode Assist login required.\n` + + `Attempting to open authentication page in your browser.\n` + + `Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`, + ); + await open(webLogin.authUrl); + console.log('Waiting for authentication...'); + + await webLogin.loginCompletePromise; + + return client; +} + +async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> { + const port = await getAvailablePort(); + const redirectUri = `http://localhost:${port}/oauth2callback`; + const state = crypto.randomBytes(32).toString('hex'); + const authUrl: string = client.generateAuthUrl({ + redirect_uri: redirectUri, + access_type: 'offline', + scope: OAUTH_SCOPE, + state, + }); + + const loginCompletePromise = new Promise<void>((resolve, reject) => { const server = http.createServer(async (req, res) => { try { if (req.url!.indexOf('/oauth2callback') === -1) { @@ -94,13 +109,16 @@ async function webLoginClient(): Promise<OAuth2Client> { reject(new Error('State mismatch. Possible CSRF attack')); } else if (qs.get('code')) { - const code: string = qs.get('code')!; - const { tokens } = await oAuth2Client.getToken(code); - oAuth2Client.setCredentials(tokens); + const { tokens } = await client.getToken({ + code: qs.get('code')!, + redirect_uri: redirectUri, + }); + client.setCredentials(tokens); + await cacheCredentials(client.credentials); res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL }); res.end(); - resolve(oAuth2Client); + resolve(); } else { reject(new Error('No code found in request')); } @@ -112,9 +130,14 @@ async function webLoginClient(): Promise<OAuth2Client> { }); server.listen(port); }); + + return { + authUrl, + loginCompletePromise, + }; } -function getAvailablePort(): Promise<number> { +export function getAvailablePort(): Promise<number> { return new Promise((resolve, reject) => { let port = 0; try { @@ -135,25 +158,20 @@ function getAvailablePort(): Promise<number> { }); } -async function getCachedCredentialClient(): Promise<OAuth2Client> { +async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> { try { const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8'); - const oAuth2Client = new OAuth2Client({ - clientId: OAUTH_CLIENT_ID, - clientSecret: OAUTH_CLIENT_SECRET, - }); - oAuth2Client.setCredentials(JSON.parse(creds)); + client.setCredentials(JSON.parse(creds)); // This will either return the existing token or refresh it. - await oAuth2Client.getAccessToken(); - // If we are here, the token is valid. - return oAuth2Client; + await client.getAccessToken(); + + return true; } catch (_) { - // Could not load credentials. - throw new Error('Could not load credentials'); + return false; } } -async function setCachedCredentials(credentials: Credentials) { +async function cacheCredentials(credentials: Credentials) { const filePath = getCachedCredentialPath(); await fs.mkdir(path.dirname(filePath), { recursive: true }); |
