diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java index 2050fe5af..1d8b48dac 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java @@ -58,6 +58,13 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { // Use the utility class to resolve client ID and client secret String clientId = OAuthClientUtils.resolveClientId(config); String clientSecret = OAuthClientUtils.resolveClientSecret(config); + OpenIDConnectEndpoints oidcEndpoints = null; + try { + oidcEndpoints = OAuthClientUtils.resolveOidcEndpoints(config); + } catch (IOException e) { + LOGGER.error("Failed to resolve OIDC endpoints: {}", e.getMessage()); + return null; + } try { if (tokenCache == null) { @@ -78,7 +85,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { new SessionCredentialsTokenSource( cachedToken, config.getHttpClient(), - config.getOidcEndpoints().getTokenEndpoint(), + oidcEndpoints.getTokenEndpoint(), clientId, clientSecret, Optional.of(config.getEffectiveOAuthRedirectUrl()), @@ -100,7 +107,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { // If no cached token or refresh failed, perform browser auth CachedTokenSource cachedTokenSource = - performBrowserAuth(config, clientId, clientSecret, tokenCache); + performBrowserAuth(config, clientId, clientSecret, tokenCache, oidcEndpoints); tokenCache.save(cachedTokenSource.getToken()); return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } catch (IOException | DatabricksException e) { @@ -109,7 +116,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { } } - protected List getScopes(DatabricksConfig config) { + protected List getScopes(DatabricksConfig config, OpenIDConnectEndpoints oidcEndpoints) { // Get user-provided scopes and add required default scopes. Set scopes = new HashSet<>(config.getScopes()); // Requesting a refresh token is most of the time the right thing to do from a @@ -125,7 +132,11 @@ protected List getScopes(DatabricksConfig config) { } CachedTokenSource performBrowserAuth( - DatabricksConfig config, String clientId, String clientSecret, TokenCache tokenCache) + DatabricksConfig config, + String clientId, + String clientSecret, + TokenCache tokenCache, + OpenIDConnectEndpoints oidcEndpoints) throws IOException { LOGGER.debug("Performing browser authentication"); @@ -138,8 +149,8 @@ CachedTokenSource performBrowserAuth( .withAccountId(config.getAccountId()) .withRedirectUrl(config.getEffectiveOAuthRedirectUrl()) .withBrowserTimeout(config.getOAuthBrowserAuthTimeout()) - .withScopes(getScopes(config)) - .withOpenIDConnectEndpoints(config.getOidcEndpoints()) + .withScopes(getScopes(config, oidcEndpoints)) + .withOpenIDConnectEndpoints(oidcEndpoints) .build(); Consent consent = client.initiateConsent(); @@ -151,7 +162,7 @@ CachedTokenSource performBrowserAuth( new SessionCredentialsTokenSource( token, config.getHttpClient(), - config.getOidcEndpoints().getTokenEndpoint(), + oidcEndpoints.getTokenEndpoint(), config.getClientId(), config.getClientSecret(), Optional.ofNullable(config.getEffectiveOAuthRedirectUrl()), diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClientUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClientUtils.java index 5908eff79..2af37d96e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClientUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClientUtils.java @@ -1,6 +1,9 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; /** Utility methods for OAuth client credentials resolution. */ public class OAuthClientUtils { @@ -39,4 +42,33 @@ public static String resolveClientSecret(DatabricksConfig config) { } return null; } + + /** + * Resolves the OAuth OIDC endpoints from the configuration. Prioritizes regular OIDC endpoints, + * then Azure OIDC endpoints. If the client ID and client secret are provided, the OIDC endpoints + * are fetched from the discovery URL. If the Azure client ID and client secret are provided, the + * OIDC endpoints are fetched from the Azure AD endpoint. If no client ID and client secret are + * provided, the OIDC endpoints are fetched from the default OIDC endpoints. + * + * @param config The Databricks configuration + * @return The resolved OIDC endpoints + * @throws IOException if the OIDC endpoints cannot be resolved + */ + public static OpenIDConnectEndpoints resolveOidcEndpoints(DatabricksConfig config) + throws IOException { + if (config.getClientId() != null && config.getClientSecret() != null) { + return config.getOidcEndpoints(); + } else if (config.getAzureClientId() != null && config.getAzureClientSecret() != null) { + Request request = new Request("GET", config.getHost() + "/oidc/oauth2/v2.0/authorize"); + request.setRedirectionBehavior(false); + Response resp = config.getHttpClient().execute(request); + String realAuthUrl = resp.getFirstHeader("location"); + if (realAuthUrl == null) { + return null; + } + return new OpenIDConnectEndpoints( + realAuthUrl.replaceAll("/authorize", "/token"), realAuthUrl); + } + return config.getOidcEndpoints(); + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java index dbd1cba9f..14f3bb25d 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java @@ -303,7 +303,8 @@ void cacheWithValidRefreshableTokenTest() throws IOException { any(DatabricksConfig.class), any(String.class), any(String.class), - any(TokenCache.class)); + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); // Verify token was NOT saved back to cache (we're using the cached one as-is). Mockito.verify(mockTokenCache, Mockito.never()).save(any(Token.class)); @@ -363,7 +364,12 @@ void cacheWithValidNonRefreshableTokenTest() throws IOException { // Verify performBrowserAuth was NOT called. Mockito.verify(provider, Mockito.never()) - .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + .performBrowserAuth( + any(DatabricksConfig.class), + any(), + any(), + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); // Verify no token was saved (we're using the cached one as-is). Mockito.verify(mockTokenCache, Mockito.never()).save(any(Token.class)); @@ -430,7 +436,8 @@ void cacheWithInvalidAccessTokenValidRefreshTest() throws IOException { any(DatabricksConfig.class), any(String.class), any(String.class), - any(TokenCache.class)); + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); // Verify token was saved back to cache Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); @@ -508,7 +515,12 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); Mockito.doReturn(cachedTokenSource) .when(provider) - .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + .performBrowserAuth( + any(DatabricksConfig.class), + any(), + any(), + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); // Spy on the config to inject the endpoints DatabricksConfig spyConfig = Mockito.spy(config); @@ -527,7 +539,12 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { // Verify performBrowserAuth was called since refresh failed Mockito.verify(provider, Mockito.times(1)) - .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + .performBrowserAuth( + any(DatabricksConfig.class), + any(), + any(), + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); // Verify token was saved after browser auth (for the new token) Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); @@ -572,17 +589,31 @@ void cacheWithInvalidTokensTest() throws IOException { new DatabricksConfig() .setAuthType("external-browser") .setHost("https://test.databricks.com") - .setClientId("test-client-id"); + .setClientId("test-client-id") + .setHttpClient(mockHttpClient); // Create our provider and mock the browser auth method ExternalBrowserCredentialsProvider provider = Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); Mockito.doReturn(cachedTokenSource) .when(provider) - .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + .performBrowserAuth( + any(DatabricksConfig.class), + any(), + any(), + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); + + // Spy on the config to inject the endpoints + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints( + "https://test.databricks.com/oidc/v1/token", + "https://test.databricks.com/oidc/v1/authorize"); + DatabricksConfig spyConfig = Mockito.spy(config); + Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints(); // Configure provider - HeaderFactory headerFactory = provider.configure(config); + HeaderFactory headerFactory = provider.configure(spyConfig); assertNotNull(headerFactory); // Verify headers contain the browser auth token (fallback) Map headers = headerFactory.headers(); @@ -593,7 +624,12 @@ void cacheWithInvalidTokensTest() throws IOException { // Verify performBrowserAuth was called since we had an invalid token Mockito.verify(provider, Mockito.times(1)) - .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + .performBrowserAuth( + any(DatabricksConfig.class), + any(), + any(), + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); // Verify token was saved after browser auth (for the new token) Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); @@ -609,7 +645,7 @@ void doNotAddOfflineAccessScopeWhenDisableOauthRefreshTokenIsTrue() { .setScopes(Arrays.asList("my-test-scope")); ExternalBrowserCredentialsProvider provider = new ExternalBrowserCredentialsProvider(); - List scopes = provider.getScopes(config); + List scopes = provider.getScopes(config, null); assertEquals(1, scopes.size()); assertTrue(scopes.contains("my-test-scope")); @@ -625,7 +661,7 @@ void doNotRemoveUserProvidedScopesWhenDisableOauthRefreshTokenIsTrue() { .setScopes(Arrays.asList("my-test-scope", "offline_access")); ExternalBrowserCredentialsProvider provider = new ExternalBrowserCredentialsProvider(); - List scopes = provider.getScopes(config); + List scopes = provider.getScopes(config, null); assertEquals(2, scopes.size()); assertTrue(scopes.contains("offline_access")); @@ -641,10 +677,98 @@ void addOfflineAccessScopeWhenDisableOauthRefreshTokenIsFalse() { .setScopes(Arrays.asList("my-test-scope")); ExternalBrowserCredentialsProvider provider = new ExternalBrowserCredentialsProvider(); - List scopes = provider.getScopes(config); + List scopes = provider.getScopes(config, null); assertEquals(2, scopes.size()); assertTrue(scopes.contains("offline_access")); assertTrue(scopes.contains("my-test-scope")); } + + @Test + void externalBrowserAuthWithAzureClientIdTest() throws IOException { + // Create mock HTTP client + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + + // Mock token cache + TokenCache mockTokenCache = Mockito.mock(TokenCache.class); + Mockito.doReturn(null).when(mockTokenCache).load(); + + // Create valid token for browser auth + Token browserAuthToken = + new Token( + "azure_access_token", "Bearer", "azure_refresh_token", Instant.now().plusSeconds(3600)); + + // Create token source + SessionCredentialsTokenSource browserAuthTokenSource = + new SessionCredentialsTokenSource( + browserAuthToken, + mockHttpClient, + "https://test.azuredatabricks.net/oidc/v1/token", + "test-azure-client-id", + null, + Optional.empty(), + Optional.empty()); + + CachedTokenSource cachedTokenSource = + new CachedTokenSource.Builder(browserAuthTokenSource).setToken(browserAuthToken).build(); + + // Create Azure config with Azure client ID + DatabricksConfig config = + new DatabricksConfig() + .setAuthType("external-browser") + .setHost("https://test.azuredatabricks.net") + .setAzureClientId("test-azure-client-id") + .setHttpClient(mockHttpClient); + + // Create provider and mock browser auth + ExternalBrowserCredentialsProvider provider = + Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); + Mockito.doReturn(cachedTokenSource) + .when(provider) + .performBrowserAuth( + any(DatabricksConfig.class), + any(), + any(), + any(TokenCache.class), + any(OpenIDConnectEndpoints.class)); + + // Spy on config to inject OIDC endpoints + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints( + "https://test.azuredatabricks.net/oidc/v1/token", + "https://test.azuredatabricks.net/oidc/v1/authorize"); + DatabricksConfig spyConfig = Mockito.spy(config); + Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints(); + + // Configure provider + HeaderFactory headerFactory = provider.configure(spyConfig); + assertNotNull(headerFactory); + + // Verify headers contain the Azure token + Map headers = headerFactory.headers(); + assertEquals("Bearer azure_access_token", headers.get("Authorization")); + + // Capture and verify the OpenIDConnectEndpoints passed to performBrowserAuth + ArgumentCaptor endpointsCaptor = + ArgumentCaptor.forClass(OpenIDConnectEndpoints.class); + Mockito.verify(provider, Mockito.times(1)) + .performBrowserAuth( + any(DatabricksConfig.class), + any(), + any(), + any(TokenCache.class), + endpointsCaptor.capture()); + + // Verify the captured endpoints match what we expect for Azure + OpenIDConnectEndpoints capturedEndpoints = endpointsCaptor.getValue(); + assertNotNull(capturedEndpoints); + assertEquals( + "https://test.azuredatabricks.net/oidc/v1/token", capturedEndpoints.getTokenEndpoint()); + assertEquals( + "https://test.azuredatabricks.net/oidc/v1/authorize", + capturedEndpoints.getAuthorizationEndpoint()); + + // Verify token was saved + Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); + } }