summaryrefslogtreecommitdiff
path: root/packages/core/src/code_assist
diff options
context:
space:
mode:
authorTommaso Sciortino <[email protected]>2025-06-18 16:34:00 -0700
committerGitHub <[email protected]>2025-06-18 16:34:00 -0700
commit8bc3b415c973794654d64d434949a93fb3239acb (patch)
tree464339fab8b05f754ae158fed50855b2a5f61231 /packages/core/src/code_assist
parentb96fbd913e8449c99ed7b95920652acdce5dd779 (diff)
Refactor in preparation for Reauth (#1196)
Diffstat (limited to 'packages/core/src/code_assist')
-rw-r--r--packages/core/src/code_assist/oauth2.test.ts9
-rw-r--r--packages/core/src/code_assist/oauth2.ts104
2 files changed, 68 insertions, 45 deletions
diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts
index 47bd45b3..0f5b791b 100644
--- a/packages/core/src/code_assist/oauth2.test.ts
+++ b/packages/core/src/code_assist/oauth2.test.ts
@@ -73,8 +73,10 @@ describe('oauth2', () => {
(resolve) => (serverListeningCallback = resolve),
);
+ let capturedPort = 0;
const mockHttpServer = {
listen: vi.fn((port: number, callback?: () => void) => {
+ capturedPort = port;
if (callback) {
callback();
}
@@ -86,7 +88,7 @@ describe('oauth2', () => {
}
}),
on: vi.fn(),
- address: () => ({ port: 1234 }),
+ address: () => ({ port: capturedPort }),
};
vi.mocked(http.createServer).mockImplementation((cb) => {
requestCallback = cb as http.RequestListener<
@@ -115,7 +117,10 @@ describe('oauth2', () => {
expect(client).toBe(mockOAuth2Client);
expect(open).toHaveBeenCalledWith(mockAuthUrl);
- expect(mockGetToken).toHaveBeenCalledWith(mockCode);
+ expect(mockGetToken).toHaveBeenCalledWith({
+ code: mockCode,
+ redirect_uri: `http://localhost:${capturedPort}/oauth2callback`,
+ });
expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);
const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts
index 9e15f65b..6527f957 100644
--- a/packages/core/src/code_assist/oauth2.ts
+++ b/packages/core/src/code_assist/oauth2.ts
@@ -42,39 +42,54 @@ const SIGN_IN_FAILURE_URL =
const GEMINI_DIR = '.gemini';
const CREDENTIAL_FILENAME = 'oauth_creds.json';
-export async function getOauthClient(): Promise<OAuth2Client> {
- try {
- return await getCachedCredentialClient();
- } catch (_) {
- const loggedInClient = await webLoginClient();
- await setCachedCredentials(loggedInClient.credentials);
- return loggedInClient;
- }
+/**
+ * An Authentication URL for updating the credentials of a Oauth2Client
+ * as well as a promise that will resolve when the credentials have
+ * been refreshed (or which throws error when refreshing credentials failed).
+ */
+export interface OauthWebLogin {
+ authUrl: string;
+ loginCompletePromise: Promise<void>;
}
-async function webLoginClient(): Promise<OAuth2Client> {
- const port = await getAvailablePort();
- const oAuth2Client = new OAuth2Client({
+export async function getOauthClient(): Promise<OAuth2Client> {
+ const client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
- redirectUri: `http://localhost:${port}/oauth2callback`,
});
- return new Promise((resolve, reject) => {
- const state = crypto.randomBytes(32).toString('hex');
- const authURL: string = oAuth2Client.generateAuthUrl({
- access_type: 'offline',
- scope: OAUTH_SCOPE,
- state,
- });
- console.log(
- `\n\nCode Assist login required.\n` +
- `Attempting to open authentication page in your browser.\n` +
- `Otherwise navigate to:\n\n${authURL}\n\n`,
- );
- open(authURL);
- console.log('Waiting for authentication...');
+ if (await loadCachedCredentials(client)) {
+ // Found valid cached credentials.
+ return client;
+ }
+ const webLogin = await authWithWeb(client);
+
+ console.log(
+ `\n\nCode Assist login required.\n` +
+ `Attempting to open authentication page in your browser.\n` +
+ `Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
+ );
+ await open(webLogin.authUrl);
+ console.log('Waiting for authentication...');
+
+ await webLogin.loginCompletePromise;
+
+ return client;
+}
+
+async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
+ const port = await getAvailablePort();
+ const redirectUri = `http://localhost:${port}/oauth2callback`;
+ const state = crypto.randomBytes(32).toString('hex');
+ const authUrl: string = client.generateAuthUrl({
+ redirect_uri: redirectUri,
+ access_type: 'offline',
+ scope: OAUTH_SCOPE,
+ state,
+ });
+
+ const loginCompletePromise = new Promise<void>((resolve, reject) => {
const server = http.createServer(async (req, res) => {
try {
if (req.url!.indexOf('/oauth2callback') === -1) {
@@ -94,13 +109,16 @@ async function webLoginClient(): Promise<OAuth2Client> {
reject(new Error('State mismatch. Possible CSRF attack'));
} else if (qs.get('code')) {
- const code: string = qs.get('code')!;
- const { tokens } = await oAuth2Client.getToken(code);
- oAuth2Client.setCredentials(tokens);
+ const { tokens } = await client.getToken({
+ code: qs.get('code')!,
+ redirect_uri: redirectUri,
+ });
+ client.setCredentials(tokens);
+ await cacheCredentials(client.credentials);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end();
- resolve(oAuth2Client);
+ resolve();
} else {
reject(new Error('No code found in request'));
}
@@ -112,9 +130,14 @@ async function webLoginClient(): Promise<OAuth2Client> {
});
server.listen(port);
});
+
+ return {
+ authUrl,
+ loginCompletePromise,
+ };
}
-function getAvailablePort(): Promise<number> {
+export function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => {
let port = 0;
try {
@@ -135,25 +158,20 @@ function getAvailablePort(): Promise<number> {
});
}
-async function getCachedCredentialClient(): Promise<OAuth2Client> {
+async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
try {
const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8');
- const oAuth2Client = new OAuth2Client({
- clientId: OAUTH_CLIENT_ID,
- clientSecret: OAUTH_CLIENT_SECRET,
- });
- oAuth2Client.setCredentials(JSON.parse(creds));
+ client.setCredentials(JSON.parse(creds));
// This will either return the existing token or refresh it.
- await oAuth2Client.getAccessToken();
- // If we are here, the token is valid.
- return oAuth2Client;
+ await client.getAccessToken();
+
+ return true;
} catch (_) {
- // Could not load credentials.
- throw new Error('Could not load credentials');
+ return false;
}
}
-async function setCachedCredentials(credentials: Credentials) {
+async function cacheCredentials(credentials: Credentials) {
const filePath = getCachedCredentialPath();
await fs.mkdir(path.dirname(filePath), { recursive: true });