Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
// Use the utility class to resolve client ID and client secret
String clientId = OAuthClientUtils.resolveClientId(config);
String clientSecret = OAuthClientUtils.resolveClientSecret(config);
OpenIDConnectEndpoints oidcEndpoints = null;
try {
oidcEndpoints = OAuthClientUtils.resolveOidcEndpoints(config);
} catch (IOException e) {
LOGGER.error("Failed to resolve OIDC endpoints: {}", e.getMessage());
return null;
}

try {
if (tokenCache == null) {
Expand All @@ -78,7 +85,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
new SessionCredentialsTokenSource(
cachedToken,
config.getHttpClient(),
config.getOidcEndpoints().getTokenEndpoint(),
oidcEndpoints.getTokenEndpoint(),
clientId,
clientSecret,
Optional.of(config.getEffectiveOAuthRedirectUrl()),
Expand All @@ -100,7 +107,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {

// If no cached token or refresh failed, perform browser auth
CachedTokenSource cachedTokenSource =
performBrowserAuth(config, clientId, clientSecret, tokenCache);
performBrowserAuth(config, clientId, clientSecret, tokenCache, oidcEndpoints);
tokenCache.save(cachedTokenSource.getToken());
return OAuthHeaderFactory.fromTokenSource(cachedTokenSource);
} catch (IOException | DatabricksException e) {
Expand All @@ -109,7 +116,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
}
}

protected List<String> getScopes(DatabricksConfig config) {
protected List<String> getScopes(DatabricksConfig config, OpenIDConnectEndpoints oidcEndpoints) {
// Get user-provided scopes and add required default scopes.
Set<String> scopes = new HashSet<>(config.getScopes());
// Requesting a refresh token is most of the time the right thing to do from a
Expand All @@ -125,7 +132,11 @@ protected List<String> getScopes(DatabricksConfig config) {
}

CachedTokenSource performBrowserAuth(
DatabricksConfig config, String clientId, String clientSecret, TokenCache tokenCache)
DatabricksConfig config,
String clientId,
String clientSecret,
TokenCache tokenCache,
OpenIDConnectEndpoints oidcEndpoints)
throws IOException {
LOGGER.debug("Performing browser authentication");

Expand All @@ -138,8 +149,8 @@ CachedTokenSource performBrowserAuth(
.withAccountId(config.getAccountId())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withBrowserTimeout(config.getOAuthBrowserAuthTimeout())
.withScopes(getScopes(config))
.withOpenIDConnectEndpoints(config.getOidcEndpoints())
.withScopes(getScopes(config, oidcEndpoints))
.withOpenIDConnectEndpoints(oidcEndpoints)
.build();
Consent consent = client.initiateConsent();

Expand All @@ -151,7 +162,7 @@ CachedTokenSource performBrowserAuth(
new SessionCredentialsTokenSource(
token,
config.getHttpClient(),
config.getOidcEndpoints().getTokenEndpoint(),
oidcEndpoints.getTokenEndpoint(),
config.getClientId(),
config.getClientSecret(),
Optional.ofNullable(config.getEffectiveOAuthRedirectUrl()),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.databricks.sdk.core.oauth;

import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.http.Request;
import com.databricks.sdk.core.http.Response;
import java.io.IOException;

/** Utility methods for OAuth client credentials resolution. */
public class OAuthClientUtils {
Expand Down Expand Up @@ -39,4 +42,33 @@ public static String resolveClientSecret(DatabricksConfig config) {
}
return null;
}

/**
* Resolves the OAuth OIDC endpoints from the configuration. Prioritizes regular OIDC endpoints,
* then Azure OIDC endpoints. If the client ID and client secret are provided, the OIDC endpoints
* are fetched from the discovery URL. If the Azure client ID and client secret are provided, the
* OIDC endpoints are fetched from the Azure AD endpoint. If no client ID and client secret are
* provided, the OIDC endpoints are fetched from the default OIDC endpoints.
*
* @param config The Databricks configuration
* @return The resolved OIDC endpoints
* @throws IOException if the OIDC endpoints cannot be resolved
*/
public static OpenIDConnectEndpoints resolveOidcEndpoints(DatabricksConfig config)
throws IOException {
if (config.getClientId() != null && config.getClientSecret() != null) {
return config.getOidcEndpoints();
} else if (config.getAzureClientId() != null && config.getAzureClientSecret() != null) {
Request request = new Request("GET", config.getHost() + "/oidc/oauth2/v2.0/authorize");
request.setRedirectionBehavior(false);
Response resp = config.getHttpClient().execute(request);
String realAuthUrl = resp.getFirstHeader("location");
if (realAuthUrl == null) {
return null;
}
return new OpenIDConnectEndpoints(
realAuthUrl.replaceAll("/authorize", "/token"), realAuthUrl);
}
return config.getOidcEndpoints();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ void cacheWithValidRefreshableTokenTest() throws IOException {
any(DatabricksConfig.class),
any(String.class),
any(String.class),
any(TokenCache.class));
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Verify token was NOT saved back to cache (we're using the cached one as-is).
Mockito.verify(mockTokenCache, Mockito.never()).save(any(Token.class));
Expand Down Expand Up @@ -363,7 +364,12 @@ void cacheWithValidNonRefreshableTokenTest() throws IOException {

// Verify performBrowserAuth was NOT called.
Mockito.verify(provider, Mockito.never())
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class),
any(),
any(),
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Verify no token was saved (we're using the cached one as-is).
Mockito.verify(mockTokenCache, Mockito.never()).save(any(Token.class));
Expand Down Expand Up @@ -430,7 +436,8 @@ void cacheWithInvalidAccessTokenValidRefreshTest() throws IOException {
any(DatabricksConfig.class),
any(String.class),
any(String.class),
any(TokenCache.class));
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Verify token was saved back to cache
Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class));
Expand Down Expand Up @@ -508,7 +515,12 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException {
Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache));
Mockito.doReturn(cachedTokenSource)
.when(provider)
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class),
any(),
any(),
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Spy on the config to inject the endpoints
DatabricksConfig spyConfig = Mockito.spy(config);
Expand All @@ -527,7 +539,12 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException {

// Verify performBrowserAuth was called since refresh failed
Mockito.verify(provider, Mockito.times(1))
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class),
any(),
any(),
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Verify token was saved after browser auth (for the new token)
Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class));
Expand Down Expand Up @@ -572,17 +589,31 @@ void cacheWithInvalidTokensTest() throws IOException {
new DatabricksConfig()
.setAuthType("external-browser")
.setHost("https://test.databricks.com")
.setClientId("test-client-id");
.setClientId("test-client-id")
.setHttpClient(mockHttpClient);

// Create our provider and mock the browser auth method
ExternalBrowserCredentialsProvider provider =
Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache));
Mockito.doReturn(cachedTokenSource)
.when(provider)
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class),
any(),
any(),
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Spy on the config to inject the endpoints
OpenIDConnectEndpoints endpoints =
new OpenIDConnectEndpoints(
"https://test.databricks.com/oidc/v1/token",
"https://test.databricks.com/oidc/v1/authorize");
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();

// Configure provider
HeaderFactory headerFactory = provider.configure(config);
HeaderFactory headerFactory = provider.configure(spyConfig);
assertNotNull(headerFactory);
// Verify headers contain the browser auth token (fallback)
Map<String, String> headers = headerFactory.headers();
Expand All @@ -593,7 +624,12 @@ void cacheWithInvalidTokensTest() throws IOException {

// Verify performBrowserAuth was called since we had an invalid token
Mockito.verify(provider, Mockito.times(1))
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class),
any(),
any(),
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Verify token was saved after browser auth (for the new token)
Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class));
Expand All @@ -609,7 +645,7 @@ void doNotAddOfflineAccessScopeWhenDisableOauthRefreshTokenIsTrue() {
.setScopes(Arrays.asList("my-test-scope"));

ExternalBrowserCredentialsProvider provider = new ExternalBrowserCredentialsProvider();
List<String> scopes = provider.getScopes(config);
List<String> scopes = provider.getScopes(config, null);

assertEquals(1, scopes.size());
assertTrue(scopes.contains("my-test-scope"));
Expand All @@ -625,7 +661,7 @@ void doNotRemoveUserProvidedScopesWhenDisableOauthRefreshTokenIsTrue() {
.setScopes(Arrays.asList("my-test-scope", "offline_access"));

ExternalBrowserCredentialsProvider provider = new ExternalBrowserCredentialsProvider();
List<String> scopes = provider.getScopes(config);
List<String> scopes = provider.getScopes(config, null);

assertEquals(2, scopes.size());
assertTrue(scopes.contains("offline_access"));
Expand All @@ -641,10 +677,98 @@ void addOfflineAccessScopeWhenDisableOauthRefreshTokenIsFalse() {
.setScopes(Arrays.asList("my-test-scope"));

ExternalBrowserCredentialsProvider provider = new ExternalBrowserCredentialsProvider();
List<String> scopes = provider.getScopes(config);
List<String> scopes = provider.getScopes(config, null);

assertEquals(2, scopes.size());
assertTrue(scopes.contains("offline_access"));
assertTrue(scopes.contains("my-test-scope"));
}

@Test
void externalBrowserAuthWithAzureClientIdTest() throws IOException {
// Create mock HTTP client
HttpClient mockHttpClient = Mockito.mock(HttpClient.class);

// Mock token cache
TokenCache mockTokenCache = Mockito.mock(TokenCache.class);
Mockito.doReturn(null).when(mockTokenCache).load();

// Create valid token for browser auth
Token browserAuthToken =
new Token(
"azure_access_token", "Bearer", "azure_refresh_token", Instant.now().plusSeconds(3600));

// Create token source
SessionCredentialsTokenSource browserAuthTokenSource =
new SessionCredentialsTokenSource(
browserAuthToken,
mockHttpClient,
"https://test.azuredatabricks.net/oidc/v1/token",
"test-azure-client-id",
null,
Optional.empty(),
Optional.empty());

CachedTokenSource cachedTokenSource =
new CachedTokenSource.Builder(browserAuthTokenSource).setToken(browserAuthToken).build();

// Create Azure config with Azure client ID
DatabricksConfig config =
new DatabricksConfig()
.setAuthType("external-browser")
.setHost("https://test.azuredatabricks.net")
.setAzureClientId("test-azure-client-id")
.setHttpClient(mockHttpClient);

// Create provider and mock browser auth
ExternalBrowserCredentialsProvider provider =
Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache));
Mockito.doReturn(cachedTokenSource)
.when(provider)
.performBrowserAuth(
any(DatabricksConfig.class),
any(),
any(),
any(TokenCache.class),
any(OpenIDConnectEndpoints.class));

// Spy on config to inject OIDC endpoints
OpenIDConnectEndpoints endpoints =
new OpenIDConnectEndpoints(
"https://test.azuredatabricks.net/oidc/v1/token",
"https://test.azuredatabricks.net/oidc/v1/authorize");
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();

// Configure provider
HeaderFactory headerFactory = provider.configure(spyConfig);
assertNotNull(headerFactory);

// Verify headers contain the Azure token
Map<String, String> headers = headerFactory.headers();
assertEquals("Bearer azure_access_token", headers.get("Authorization"));

// Capture and verify the OpenIDConnectEndpoints passed to performBrowserAuth
ArgumentCaptor<OpenIDConnectEndpoints> endpointsCaptor =
ArgumentCaptor.forClass(OpenIDConnectEndpoints.class);
Mockito.verify(provider, Mockito.times(1))
.performBrowserAuth(
any(DatabricksConfig.class),
any(),
any(),
any(TokenCache.class),
endpointsCaptor.capture());

// Verify the captured endpoints match what we expect for Azure
OpenIDConnectEndpoints capturedEndpoints = endpointsCaptor.getValue();
assertNotNull(capturedEndpoints);
assertEquals(
"https://test.azuredatabricks.net/oidc/v1/token", capturedEndpoints.getTokenEndpoint());
assertEquals(
"https://test.azuredatabricks.net/oidc/v1/authorize",
capturedEndpoints.getAuthorizationEndpoint());

// Verify token was saved
Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class));
}
}
Loading