20
20
21
21
import java .io .File ;
22
22
import java .io .IOException ;
23
-
24
- import org .slf4j .Logger ;
25
- import org .slf4j .LoggerFactory ;
26
23
import org .apache .commons .io .FileUtils ;
27
24
import org .apache .hadoop .classification .VisibleForTesting ;
25
+ import org .apache .hadoop .conf .Configuration ;
28
26
import org .apache .hadoop .thirdparty .com .google .common .base .Strings ;
29
27
import org .apache .hadoop .util .Preconditions ;
28
+ import org .slf4j .Logger ;
29
+ import org .slf4j .LoggerFactory ;
30
30
31
31
/**
32
32
* Provides tokens based on Azure AD Workload Identity.
@@ -38,11 +38,72 @@ public class WorkloadIdentityTokenProvider extends AccessTokenProvider {
38
38
private static final String EMPTY_TOKEN_FILE_ERROR = "Empty token file found at specified path: " ;
39
39
private static final String TOKEN_FILE_READ_ERROR = "Error reading token file at specified path: " ;
40
40
41
+ /**
42
+ * Internal implementation of ClientAssertionProvider for file-based token reading.
43
+ * This provides backward compatibility for the file-based constructor.
44
+ */
45
+ private static class FileBasedClientAssertionProvider implements ClientAssertionProvider {
46
+ private final String tokenFile ;
47
+
48
+ public FileBasedClientAssertionProvider (String tokenFile ) {
49
+ this .tokenFile = tokenFile ;
50
+ }
51
+
52
+ @ Override
53
+ public void initialize (Configuration configuration , String accountName ) throws IOException {
54
+ // No initialization needed for file-based provider
55
+ }
56
+
57
+ @ Override
58
+ public String getClientAssertion () throws IOException {
59
+ String clientAssertion = "" ;
60
+ try {
61
+ File file = new File (tokenFile );
62
+ clientAssertion = FileUtils .readFileToString (file , "UTF-8" );
63
+ } catch (Exception e ) {
64
+ throw new IOException (TOKEN_FILE_READ_ERROR + tokenFile , e );
65
+ }
66
+ if (Strings .isNullOrEmpty (clientAssertion )) {
67
+ throw new IOException (EMPTY_TOKEN_FILE_ERROR + tokenFile );
68
+ }
69
+ return clientAssertion ;
70
+ }
71
+ }
72
+
41
73
private final String authEndpoint ;
42
74
private final String clientId ;
43
- private final String tokenFile ;
75
+ private final ClientAssertionProvider clientAssertionProvider ;
44
76
private long tokenFetchTime = -1 ;
45
77
78
+ /**
79
+ * Constructor with custom ClientAssertionProvider.
80
+ * Use this for custom token retrieval mechanisms like Kubernetes Token Request API.
81
+ *
82
+ * @param authority OAuth authority URL
83
+ * @param tenantId Azure AD tenant ID
84
+ * @param clientId Azure AD client ID
85
+ * @param clientAssertionProvider Custom provider for client assertions
86
+ */
87
+ public WorkloadIdentityTokenProvider (final String authority , final String tenantId ,
88
+ final String clientId , ClientAssertionProvider clientAssertionProvider ) {
89
+ Preconditions .checkNotNull (authority , "authority" );
90
+ Preconditions .checkNotNull (tenantId , "tenantId" );
91
+ Preconditions .checkNotNull (clientId , "clientId" );
92
+ Preconditions .checkNotNull (clientAssertionProvider , "clientAssertionProvider" );
93
+
94
+ this .authEndpoint = authority + tenantId + OAUTH2_TOKEN_PATH ;
95
+ this .clientId = clientId ;
96
+ this .clientAssertionProvider = clientAssertionProvider ;
97
+ }
98
+
99
+ /**
100
+ * Constructor with file-based token reading (backward compatibility).
101
+ *
102
+ * @param authority OAuth authority URL
103
+ * @param tenantId Azure AD tenant ID
104
+ * @param clientId Azure AD client ID
105
+ * @param tokenFile Path to file containing the JWT token
106
+ */
46
107
public WorkloadIdentityTokenProvider (final String authority , final String tenantId ,
47
108
final String clientId , final String tokenFile ) {
48
109
Preconditions .checkNotNull (authority , "authority" );
@@ -52,13 +113,13 @@ public WorkloadIdentityTokenProvider(final String authority, final String tenant
52
113
53
114
this .authEndpoint = authority + tenantId + OAUTH2_TOKEN_PATH ;
54
115
this .clientId = clientId ;
55
- this .tokenFile = tokenFile ;
116
+ this .clientAssertionProvider = new FileBasedClientAssertionProvider ( tokenFile ) ;
56
117
}
57
118
58
119
@ Override
59
120
protected AzureADToken refreshToken () throws IOException {
60
121
LOG .debug ("AADToken: refreshing token from JWT Assertion" );
61
- String clientAssertion = getClientAssertion ();
122
+ String clientAssertion = clientAssertionProvider . getClientAssertion ();
62
123
AzureADToken token = getTokenUsingJWTAssertion (clientAssertion );
63
124
tokenFetchTime = System .currentTimeMillis ();
64
125
return token ;
@@ -90,31 +151,6 @@ protected boolean isTokenAboutToExpire() {
90
151
return expiring ;
91
152
}
92
153
93
- /**
94
- * Gets the client assertion from the token file.
95
- * The token file should contain the client assertion in JWT format.
96
- * It should be a String containing Base64Url encoded JSON Web Token (JWT).
97
- * See <a href="https://azure.github.io/azure-workload-identity/docs/faq.html#does-workload-identity-work-in-disconnected-environments">
98
- * Azure Workload Identity FAQ</a>.
99
- *
100
- * @return the client assertion.
101
- * @throws IOException if the token file is empty.
102
- */
103
- private String getClientAssertion ()
104
- throws IOException {
105
- String clientAssertion = "" ;
106
- try {
107
- File file = new File (tokenFile );
108
- clientAssertion = FileUtils .readFileToString (file , "UTF-8" );
109
- } catch (Exception e ) {
110
- throw new IOException (TOKEN_FILE_READ_ERROR + tokenFile , e );
111
- }
112
- if (Strings .isNullOrEmpty (clientAssertion )) {
113
- throw new IOException (EMPTY_TOKEN_FILE_ERROR + tokenFile );
114
- }
115
- return clientAssertion ;
116
- }
117
-
118
154
/**
119
155
* Gets the Azure AD token from a client assertion in JWT format.
120
156
* This method exists to make unit testing possible.
0 commit comments