1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { getOauthClient } from './oauth2.js';
import { OAuth2Client } from 'google-auth-library';
import * as fs from 'fs';
import * as path from 'path';
import http from 'http';
import open from 'open';
import crypto from 'crypto';
import * as os from 'os';
// Keep the real `os` module intact except for homedir(), which is replaced by
// a bare mock so tests can redirect it to a directory of their choosing.
// (vi.mock calls are hoisted by vitest, so they must stay top-level calls with
// literal module specifiers.)
vi.mock('os', async (importOriginal) => {
  const actualOs = await importOriginal<typeof import('os')>();
  return { ...actualOs, homedir: vi.fn() };
});

// Automock the remaining dependencies wholesale; the test installs concrete
// implementations for the pieces it actually exercises.
vi.mock('google-auth-library');
vi.mock('http');
vi.mock('open');
vi.mock('crypto');
describe('oauth2', () => {
  let tempHomeDir: string;

  beforeEach(() => {
    // Give every test a fresh, empty home directory (via the mocked
    // os.homedir) so files written by the code under test stay in a sandbox.
    tempHomeDir = fs.mkdtempSync(
      path.join(os.tmpdir(), 'gemini-cli-test-home-'),
    );
    vi.mocked(os.homedir).mockReturnValue(tempHomeDir);
  });

  afterEach(() => {
    fs.rmSync(tempHomeDir, { recursive: true, force: true });
    // Forget recorded calls so one test's call assertions cannot be satisfied
    // by calls made in another. Configured mock implementations are kept.
    vi.clearAllMocks();
  });

  it('should perform a web login', async () => {
    const mockAuthUrl = 'https://example.com/auth';
    const mockCode = 'test-code';
    const mockState = 'test-state';
    const mockTokens = {
      access_token: 'test-access-token',
      refresh_token: 'test-refresh-token',
    };

    // Make the (automocked) OAuth2Client constructor return a hand-built fake
    // whose methods the assertions below can inspect.
    const mockGenerateAuthUrl = vi.fn().mockReturnValue(mockAuthUrl);
    const mockGetToken = vi.fn().mockResolvedValue({ tokens: mockTokens });
    const mockSetCredentials = vi.fn();
    const mockOAuth2Client = {
      generateAuthUrl: mockGenerateAuthUrl,
      getToken: mockGetToken,
      setCredentials: mockSetCredentials,
      credentials: mockTokens,
    } as unknown as OAuth2Client;
    vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client);

    // NOTE(review): presumably getOauthClient() derives its OAuth `state`
    // value from crypto.randomBytes(); pinning it to mockState lets the
    // simulated callback URL below carry a matching state.
    vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never);
    // Stub `open` so nothing is launched; the call is asserted below.
    vi.mocked(open).mockImplementation(async () => ({}) as never);

    // Captured from http.createServer so the test can deliver the OAuth
    // redirect request to the local callback server itself.
    let requestCallback!: http.RequestListener<
      typeof http.IncomingMessage,
      typeof http.ServerResponse
    >;

    // Settled once the fake server's listen() has been invoked.
    let serverListeningCallback: (value: unknown) => void;
    const serverListeningPromise = new Promise(
      (resolve) => (serverListeningCallback = resolve),
    );

    // Port passed to listen(), echoed back through address().
    let capturedPort = 0;
    const mockHttpServer = {
      listen: vi.fn((port: number, callback?: () => void) => {
        capturedPort = port;
        if (callback) {
          callback();
        }
        serverListeningCallback(undefined);
      }),
      close: vi.fn((callback?: () => void) => {
        if (callback) {
          callback();
        }
      }),
      on: vi.fn(),
      address: () => ({ port: capturedPort }),
    };
    vi.mocked(http.createServer).mockImplementation((cb) => {
      requestCallback = cb as http.RequestListener<
        typeof http.IncomingMessage,
        typeof http.ServerResponse
      >;
      return mockHttpServer as unknown as http.Server;
    });

    const clientPromise = getOauthClient();

    // Wait for the server to start listening. Racing against clientPromise
    // means a rejection from getOauthClient() that happens before listen() is
    // reached fails the test immediately with the real error, instead of
    // leaving serverListeningPromise pending until the test times out.
    await Promise.race([serverListeningPromise, clientPromise]);

    // Simulate the redirect back to the local server with a code and state.
    const mockReq = {
      url: `/oauth2callback?code=${mockCode}&state=${mockState}`,
    } as http.IncomingMessage;
    const mockRes = {
      writeHead: vi.fn(),
      end: vi.fn(),
    } as unknown as http.ServerResponse;

    await requestCallback(mockReq, mockRes);

    const client = await clientPromise;
    expect(client).toBe(mockOAuth2Client);

    expect(open).toHaveBeenCalledWith(mockAuthUrl);
    // The code must be exchanged with a redirect URI on the port the server
    // actually listened on.
    expect(mockGetToken).toHaveBeenCalledWith({
      code: mockCode,
      redirect_uri: `http://localhost:${capturedPort}/oauth2callback`,
    });
    expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);

    // The obtained tokens must be persisted under the mocked home directory.
    const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
    const tokenData = JSON.parse(fs.readFileSync(tokenPath, 'utf-8'));
    expect(tokenData).toEqual(mockTokens);
  });
});
|