diff options
| author | Brian Ray <[email protected]> | 2025-07-18 10:14:23 -0400 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-07-18 14:14:23 +0000 |
| commit | c5761317f4918545d8c5598c5d6204ded534f51e (patch) | |
| tree | 30300515c97236219eba92d9c3268a231b824ae4 /packages/core/src/mcp/oauth-provider.test.ts | |
| parent | de27ea6095f3f36f8939af0aaeaf9362e3c61490 (diff) | |
MCP OAuth Part 1 - OAuth Infrastructure (#4316)
Diffstat (limited to 'packages/core/src/mcp/oauth-provider.test.ts')
| -rw-r--r-- | packages/core/src/mcp/oauth-provider.test.ts | 720 |
1 files changed, 720 insertions, 0 deletions
diff --git a/packages/core/src/mcp/oauth-provider.test.ts b/packages/core/src/mcp/oauth-provider.test.ts new file mode 100644 index 00000000..41938969 --- /dev/null +++ b/packages/core/src/mcp/oauth-provider.test.ts @@ -0,0 +1,720 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import * as http from 'node:http'; +import * as crypto from 'node:crypto'; +import open from 'open'; +import { + MCPOAuthProvider, + MCPOAuthConfig, + OAuthTokenResponse, + OAuthClientRegistrationResponse, +} from './oauth-provider.js'; +import { MCPOAuthTokenStorage, MCPOAuthToken } from './oauth-token-storage.js'; + +// Mock dependencies +vi.mock('open'); +vi.mock('node:crypto'); +vi.mock('./oauth-token-storage.js'); + +// Mock fetch globally +const mockFetch = vi.fn(); +global.fetch = mockFetch; + +// Define a reusable mock server with .listen, .close, and .on methods +const mockHttpServer = { + listen: vi.fn(), + close: vi.fn(), + on: vi.fn(), +}; +vi.mock('node:http', () => ({ + createServer: vi.fn(() => mockHttpServer), +})); + +describe('MCPOAuthProvider', () => { + const mockConfig: MCPOAuthConfig = { + enabled: true, + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + authorizationUrl: 'https://auth.example.com/authorize', + tokenUrl: 'https://auth.example.com/token', + scopes: ['read', 'write'], + redirectUri: 'http://localhost:7777/oauth/callback', + }; + + const mockToken: MCPOAuthToken = { + accessToken: 'access_token_123', + refreshToken: 'refresh_token_456', + tokenType: 'Bearer', + scope: 'read write', + expiresAt: Date.now() + 3600000, + }; + + const mockTokenResponse: OAuthTokenResponse = { + access_token: 'access_token_123', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'refresh_token_456', + scope: 'read write', + }; + + beforeEach(() => { + vi.clearAllMocks(); + vi.spyOn(console, 'log').mockImplementation(() => {}); + vi.spyOn(console, 'warn').mockImplementation(() => {}); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + // Mock crypto functions + vi.mocked(crypto.randomBytes).mockImplementation((size: number) => { + if (size === 32) return Buffer.from('code_verifier_mock_32_bytes_long'); + if (size === 16) return Buffer.from('state_mock_16_by'); + return Buffer.alloc(size); + }); + + vi.mocked(crypto.createHash).mockReturnValue({ + update: vi.fn().mockReturnThis(), + digest: vi.fn().mockReturnValue('code_challenge_mock'), + } as unknown as crypto.Hash); + + // Mock randomBytes to return predictable values for state + vi.mocked(crypto.randomBytes).mockImplementation((size) => { + if (size === 32) { + return Buffer.from('mock_code_verifier_32_bytes_long_string'); + } else if (size === 16) { + return Buffer.from('mock_state_16_bytes'); + } + return Buffer.alloc(size); + }); + + // Mock token storage + vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined); + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('authenticate', () => { + it('should perform complete OAuth flow with PKCE', async () => { + // Mock HTTP server callback + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + // Simulate OAuth callback + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + // Mock token exchange + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + const result = await MCPOAuthProvider.authenticate( + 'test-server', + mockConfig, + ); + + expect(result).toEqual({ + accessToken: 'access_token_123', + refreshToken: 'refresh_token_456', + tokenType: 'Bearer', + scope: 'read write', + expiresAt: expect.any(Number), + }); + + expect(open).toHaveBeenCalledWith(expect.stringContaining('authorize')); + expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( + 'test-server', + expect.objectContaining({ accessToken: 'access_token_123' }), + 'test-client-id', + 'https://auth.example.com/token', + ); + }); + + it('should handle OAuth discovery when no authorization URL provided', async () => { + // Use a mutable config object + const configWithoutAuth: MCPOAuthConfig = { ...mockConfig }; + delete configWithoutAuth.authorizationUrl; + delete configWithoutAuth.tokenUrl; + + const mockResourceMetadata = { + authorization_servers: ['https://discovered.auth.com'], + }; + + const mockAuthServerMetadata = { + authorization_endpoint: 'https://discovered.auth.com/authorize', + token_endpoint: 'https://discovered.auth.com/token', + scopes_supported: ['read', 'write'], + }; + + mockFetch + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockResourceMetadata), + }) + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockAuthServerMetadata), + }); + + // Patch config after discovery + configWithoutAuth.authorizationUrl = + mockAuthServerMetadata.authorization_endpoint; + configWithoutAuth.tokenUrl = mockAuthServerMetadata.token_endpoint; + configWithoutAuth.scopes = mockAuthServerMetadata.scopes_supported; + + // Setup callback handler + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + // Mock token exchange with discovered endpoint + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + const result = await MCPOAuthProvider.authenticate( + 'test-server', + configWithoutAuth, + 'https://api.example.com', + ); + + expect(result).toBeDefined(); + expect(mockFetch).toHaveBeenCalledWith( + 'https://discovered.auth.com/token', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + }), + ); + }); + + it('should perform dynamic client registration when no client ID provided', async () => { + const configWithoutClient = { ...mockConfig }; + delete configWithoutClient.clientId; + + const mockRegistrationResponse: OAuthClientRegistrationResponse = { + client_id: 'dynamic_client_id', + client_secret: 'dynamic_client_secret', + redirect_uris: ['http://localhost:7777/oauth/callback'], + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', + }; + + const mockAuthServerMetadata = { + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + }; + + mockFetch + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockAuthServerMetadata), + }) + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockRegistrationResponse), + }); + + // Setup callback handler + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + // Mock token exchange + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + const result = await MCPOAuthProvider.authenticate( + 'test-server', + configWithoutClient, + ); + + expect(result).toBeDefined(); + expect(mockFetch).toHaveBeenCalledWith( + 'https://auth.example.com/register', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + }), + ); + }); + + it('should handle OAuth callback errors', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?error=access_denied&error_description=User%20denied%20access', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('OAuth error: access_denied'); + }); + + it('should handle state mismatch in callback', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=wrong_state', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('State mismatch - possible CSRF attack'); + }); + + it('should handle token exchange failure', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + text: () => Promise.resolve('Invalid grant'), + }); + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('Token exchange failed: 400 - Invalid grant'); + }); + + it('should handle callback timeout', async () => { + vi.mocked(http.createServer).mockImplementation( + () => mockHttpServer as unknown as http.Server, + ); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + // Don't trigger callback - simulate timeout + }); + + // Mock setTimeout to trigger timeout immediately + const originalSetTimeout = global.setTimeout; + global.setTimeout = vi.fn((callback, delay) => { + if (delay === 5 * 60 * 1000) { + // 5 minute timeout + callback(); + } + return originalSetTimeout(callback, 0); + }) as unknown as typeof setTimeout; + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('OAuth callback timeout'); + + global.setTimeout = originalSetTimeout; + }); + }); + + describe('refreshAccessToken', () => { + it('should refresh token successfully', async () => { + const refreshResponse = { + access_token: 'new_access_token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new_refresh_token', + }; + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(refreshResponse), + }); + + const result = await MCPOAuthProvider.refreshAccessToken( + mockConfig, + 'old_refresh_token', + 'https://auth.example.com/token', + ); + + expect(result).toEqual(refreshResponse); + expect(mockFetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: expect.stringContaining('grant_type=refresh_token'), + }), + ); + }); + + it('should include client secret in refresh request when available', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + await MCPOAuthProvider.refreshAccessToken( + mockConfig, + 'refresh_token', + 'https://auth.example.com/token', + ); + + const fetchCall = mockFetch.mock.calls[0]; + expect(fetchCall[1].body).toContain('client_secret=test-client-secret'); + }); + + it('should handle refresh token failure', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + text: () => Promise.resolve('Invalid refresh token'), + }); + + await expect( + MCPOAuthProvider.refreshAccessToken( + mockConfig, + 'invalid_refresh_token', + 'https://auth.example.com/token', + ), + ).rejects.toThrow('Token refresh failed: 400 - Invalid refresh token'); + }); + }); + + describe('getValidToken', () => { + it('should return valid token when not expired', async () => { + const validCredentials = { + serverName: 'test-server', + token: mockToken, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + validCredentials, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBe('access_token_123'); + }); + + it('should refresh expired token and return new token', async () => { + const expiredCredentials = { + serverName: 'test-server', + token: { ...mockToken, expiresAt: Date.now() - 3600000 }, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + expiredCredentials, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + + const refreshResponse = { + access_token: 'new_access_token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new_refresh_token', + }; + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(refreshResponse), + }); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBe('new_access_token'); + expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( + 'test-server', + expect.objectContaining({ accessToken: 'new_access_token' }), + 'test-client-id', + 'https://auth.example.com/token', + ); + }); + + it('should return null when no credentials exist', async () => { + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBeNull(); + }); + + it('should handle refresh failure and remove invalid token', async () => { + const expiredCredentials = { + serverName: 'test-server', + token: { ...mockToken, expiresAt: Date.now() - 3600000 }, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + expiredCredentials, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined); + + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + text: () => Promise.resolve('Invalid refresh token'), + }); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBeNull(); + expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith( + 'test-server', + ); + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to refresh token'), + ); + }); + + it('should return null for token without refresh capability', async () => { + const tokenWithoutRefresh = { + serverName: 'test-server', + token: { + ...mockToken, + refreshToken: undefined, + expiresAt: Date.now() - 3600000, + }, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + tokenWithoutRefresh, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBeNull(); + }); + }); + + describe('PKCE parameter generation', () => { + it('should generate valid PKCE parameters', async () => { + // Test is implicit in the authenticate flow tests, but we can verify + // the crypto mocks are called correctly + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + await MCPOAuthProvider.authenticate('test-server', mockConfig); + + expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier + expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state + expect(crypto.createHash).toHaveBeenCalledWith('sha256'); + }); + }); + + describe('Authorization URL building', () => { + it('should build correct authorization URL with all parameters', async () => { + // Mock to capture the URL that would be opened + let capturedUrl: string; + vi.mocked(open).mockImplementation((url) => { + capturedUrl = url; + // Return a minimal mock ChildProcess + return Promise.resolve({ + pid: 1234, + } as unknown as import('child_process').ChildProcess); + }); + + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + await MCPOAuthProvider.authenticate('test-server', mockConfig); + + expect(capturedUrl!).toContain('response_type=code'); + expect(capturedUrl!).toContain('client_id=test-client-id'); + expect(capturedUrl!).toContain('code_challenge=code_challenge_mock'); + expect(capturedUrl!).toContain('code_challenge_method=S256'); + expect(capturedUrl!).toContain('scope=read+write'); + expect(capturedUrl!).toContain('resource=https%3A%2F%2Fauth.example.com'); + }); + }); +}); |
