Skip to content

Commit 9883b4d

Browse files
jordeupditommaso
andauthored
Fix JWT refresh and use TowerClient to handle Fusion token requests(#6207)
Signed-off-by: Jordi Deu-Pons <[email protected]> Signed-off-by: Paolo Di Tommaso <[email protected]> Co-authored-by: Paolo Di Tommaso <[email protected]>
1 parent c4fa744 commit 9883b4d

File tree

4 files changed

+134
-394
lines changed

4 files changed

+134
-394
lines changed

plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerFactory.groovy

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,13 @@ class TowerFactory implements TraceObserverFactory {
100100
@Memoized
101101
static TowerClient client(Session session, Map<String,String> env) {
102102
final config = session.config
103-
Boolean isEnabled = config.navigate('tower.enabled') as Boolean || env.get('TOWER_WORKFLOW_ID')
103+
Boolean isEnabled = config.navigate('tower.enabled') as Boolean || env.get('TOWER_WORKFLOW_ID') || config.navigate('fusion.enabled')
104104
return isEnabled
105105
? createTowerClient0(session, config, env)
106106
: null
107107
}
108108

109109
static TowerClient client() {
110-
client(Global.session as Session, SysEnv.get())
110+
return client(Global.session as Session, SysEnv.get())
111111
}
112112
}
Lines changed: 19 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,15 @@
11
package io.seqera.tower.plugin
22

3-
import java.net.http.HttpClient
4-
import java.net.http.HttpRequest
5-
import java.net.http.HttpResponse
6-
import java.time.Duration
7-
import java.time.Instant
8-
import java.time.temporal.ChronoUnit
9-
import java.util.concurrent.Executors
10-
import java.util.function.Predicate
11-
123
import com.google.common.cache.Cache
134
import com.google.common.cache.CacheBuilder
145
import com.google.common.util.concurrent.UncheckedExecutionException
15-
import com.google.gson.Gson
166
import com.google.gson.JsonSyntaxException
17-
import dev.failsafe.Failsafe
18-
import dev.failsafe.RetryPolicy
19-
import dev.failsafe.event.EventListener
20-
import dev.failsafe.event.ExecutionAttemptedEvent
21-
import dev.failsafe.function.CheckedSupplier
227
import groovy.transform.CompileStatic
238
import groovy.util.logging.Slf4j
249
import io.seqera.tower.plugin.exception.BadResponseException
2510
import io.seqera.tower.plugin.exception.UnauthorizedException
2611
import io.seqera.tower.plugin.exchange.GetLicenseTokenRequest
2712
import io.seqera.tower.plugin.exchange.GetLicenseTokenResponse
28-
import io.seqera.util.trace.TraceUtils
2913
import nextflow.SysEnv
3014
import nextflow.exception.AbortOperationException
3115
import nextflow.exception.ReportWarningException
@@ -34,8 +18,12 @@ import nextflow.fusion.FusionToken
3418
import nextflow.platform.PlatformHelper
3519
import nextflow.plugin.Priority
3620
import nextflow.util.GsonHelper
37-
import nextflow.util.Threads
3821
import org.pf4j.Extension
22+
23+
import java.time.Duration
24+
import java.time.Instant
25+
import java.time.temporal.ChronoUnit
26+
3927
/**
4028
* Environment provider for Platform-specific environment variables.
4129
*
@@ -50,24 +38,6 @@ class TowerFusionToken implements FusionToken {
5038
// The path relative to the Platform endpoint where license-scoped JWT tokens are obtained
5139
private static final String LICENSE_TOKEN_PATH = 'license/token/'
5240

53-
// Server errors that should trigger a retry
54-
private static final List<Integer> SERVER_ERRORS = [408, 429, 500, 502, 503, 504]
55-
56-
// Default connection timeout for HTTP requests
57-
private static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.of(30, ChronoUnit.SECONDS)
58-
59-
// Default retry policy settings for HTTP requests: delay, max delay, attempts, and jitter
60-
private static final Duration DEFAULT_RETRY_POLICY_DELAY = Duration.of(450, ChronoUnit.MILLIS)
61-
private static final Duration DEFAULT_RETRY_POLICY_MAX_DELAY = Duration.of(90, ChronoUnit.SECONDS)
62-
private static final int DEFAULT_RETRY_POLICY_MAX_ATTEMPTS = 10
63-
private static final double DEFAULT_RETRY_POLICY_JITTER = 0.5
64-
65-
// The HttpClient instance used to send requests
66-
private final HttpClient httpClient = newDefaultHttpClient()
67-
68-
// The RetryPolicy instance used to retry requests
69-
private final RetryPolicy retryPolicy = newDefaultRetryPolicy(SERVER_ERRORS)
70-
7141
// Time-to-live for cached tokens
7242
private Duration tokenTTL = Duration.of(1, ChronoUnit.HOURS)
7343

@@ -76,32 +46,21 @@ class TowerFusionToken implements FusionToken {
7646
.expireAfterWrite(tokenTTL)
7747
.build()
7848

79-
// Platform endpoint to use for requests
80-
private String endpoint
81-
82-
// Platform access token to use for requests
83-
private String accessToken
84-
8549
// Platform workflowId
8650
private String workspaceId
8751

8852
// Platform workflowId
8953
private String workflowId
9054

55+
// Platform client to handle all the requests
56+
private TowerClient client
57+
9158
TowerFusionToken() {
9259
final config = PlatformHelper.config()
9360
final env = SysEnv.get()
94-
this.endpoint = PlatformHelper.getEndpoint(config, env)
95-
this.accessToken = PlatformHelper.getAccessToken(config, env)
9661
this.workflowId = env.get('TOWER_WORKFLOW_ID')
9762
this.workspaceId = PlatformHelper.getWorkspaceId(config, env)
98-
}
99-
100-
protected void validateConfig() {
101-
if( !endpoint )
102-
throw new IllegalArgumentException("Missing Seqera Platform endpoint")
103-
if( !accessToken )
104-
throw new IllegalArgumentException("Missing Seqera Platform access token")
63+
this.client = TowerFactory.client()
10564
}
10665

10766
/**
@@ -125,7 +84,6 @@ class TowerFusionToken implements FusionToken {
12584
}
12685

12786
protected Map<String,String> getEnvironment0(String scheme, FusionConfig config) {
128-
validateConfig()
12987
final product = config.sku()
13088
final version = config.version()
13189
final token = getLicenseToken(product, version)
@@ -178,102 +136,6 @@ class TowerFusionToken implements FusionToken {
178136
* Helper methods
179137
*************************************************************************/
180138

181-
/**
182-
* Create a new HttpClient instance with default settings
183-
* @return The new HttpClient instance
184-
*/
185-
private static HttpClient newDefaultHttpClient() {
186-
final builder = HttpClient.newBuilder()
187-
.version(HttpClient.Version.HTTP_1_1)
188-
.followRedirects(HttpClient.Redirect.NEVER)
189-
.cookieHandler(new CookieManager())
190-
.connectTimeout(DEFAULT_CONNECTION_TIMEOUT)
191-
// use virtual threads executor if enabled
192-
if ( Threads.useVirtual() ) {
193-
builder.executor(Executors.newVirtualThreadPerTaskExecutor())
194-
}
195-
// build and return the new client
196-
return builder.build()
197-
}
198-
199-
/**
200-
* Create a new RetryPolicy instance with default settings and the given list of retryable errors. With this policy,
201-
* a request is retried on IOExceptions and any server errors defined in errorsToRetry. The number of retries, delay,
202-
* max delay, and jitter are controlled by the corresponding values defined at class level.
203-
*
204-
* @return The new RetryPolicy instance
205-
*/
206-
private static <T> RetryPolicy<HttpResponse<T>> newDefaultRetryPolicy(List<Integer> errorsToRetry) {
207-
208-
final retryOnException = (e -> e instanceof IOException) as Predicate<? extends Throwable>
209-
final retryOnStatusCode = ((HttpResponse<T> resp) -> resp.statusCode() in errorsToRetry) as Predicate<HttpResponse<T>>
210-
211-
final listener = new EventListener<ExecutionAttemptedEvent<HttpResponse<T>>>() {
212-
@Override
213-
void accept(ExecutionAttemptedEvent event) throws Throwable {
214-
def msg = "connection failure - attempt: ${event.attemptCount}"
215-
if (event.lastResult != null)
216-
msg += "; response: ${event.lastResult}"
217-
if (event.lastFailure != null)
218-
msg += "; exception: [${event.lastFailure.class.name}] ${event.lastFailure.message}"
219-
log.debug(msg)
220-
}
221-
}
222-
return RetryPolicy.<HttpResponse<T>> builder()
223-
.handleIf(retryOnException)
224-
.handleResultIf(retryOnStatusCode)
225-
.withBackoff(DEFAULT_RETRY_POLICY_DELAY.toMillis(), DEFAULT_RETRY_POLICY_MAX_DELAY.toMillis(), ChronoUnit.MILLIS)
226-
.withMaxAttempts(DEFAULT_RETRY_POLICY_MAX_ATTEMPTS)
227-
.withJitter(DEFAULT_RETRY_POLICY_JITTER)
228-
.onRetry(listener)
229-
.build()
230-
}
231-
232-
/**
233-
* Send an HTTP request and return the response. This method automatically retries the request according to the
234-
* given RetryPolicy.
235-
*
236-
* @param req The HttpRequest to send
237-
* @return The HttpResponse received
238-
*/
239-
private <T> HttpResponse<String> safeHttpSend(HttpRequest req, RetryPolicy<T> policy) {
240-
return Failsafe.with(policy).get(
241-
() -> {
242-
log.debug "Http request: method=${req.method()}; uri=${req.uri()}; request=${req}"
243-
final resp = httpClient.send(req, HttpResponse.BodyHandlers.ofString())
244-
log.debug "Http response: statusCode=${resp.statusCode()}; body=${resp.body()}"
245-
return resp
246-
} as CheckedSupplier
247-
) as HttpResponse<String>
248-
}
249-
250-
/**
251-
* Create a {@link HttpRequest} representing a {@link GetLicenseTokenRequest} object
252-
*
253-
* @param req The LicenseTokenRequest object
254-
* @return The resulting HttpRequest object
255-
*/
256-
private HttpRequest makeHttpRequest(GetLicenseTokenRequest req) {
257-
final body = HttpRequest.BodyPublishers.ofString( GsonHelper.toJson(req) )
258-
return HttpRequest.newBuilder()
259-
.uri(URI.create("${endpoint}/${LICENSE_TOKEN_PATH}").normalize())
260-
.header('Content-Type', 'application/json')
261-
.header('Traceparent', TraceUtils.rndTrace())
262-
.header('Authorization', "Bearer ${accessToken}")
263-
.POST(body)
264-
.build()
265-
}
266-
267-
/**
268-
* Serialize a {@link GetLicenseTokenRequest} object into a JSON string
269-
*
270-
* @param req The LicenseTokenRequest object
271-
* @return The resulting JSON string
272-
*/
273-
private static String serializeToJson(GetLicenseTokenRequest req) {
274-
return new Gson().toJson(req)
275-
}
276-
277139
/**
278140
* Parse a JSON string into a {@link GetLicenseTokenResponse} object
279141
*
@@ -299,24 +161,18 @@ class TowerFusionToken implements FusionToken {
299161
*/
300162
private GetLicenseTokenResponse sendRequest(GetLicenseTokenRequest req) throws AbortOperationException, UnauthorizedException, BadResponseException, IllegalStateException {
301163

302-
final httpReq = makeHttpRequest(req)
303-
304-
try {
305-
final resp = safeHttpSend(httpReq, retryPolicy)
306-
307-
if( resp.statusCode() == 200 ) {
308-
final ret = parseLicenseTokenResponse(resp.body())
309-
return ret
310-
}
164+
final url = "${client.getEndpoint()}/${LICENSE_TOKEN_PATH}"
165+
final resp = client.sendHttpMessage(url, req.toMap())
311166

312-
if( resp.statusCode() == 401 ) {
313-
throw new UnauthorizedException("Unauthorized [401] - Verify you have provided a Seqera Platform valid access token")
314-
}
315-
316-
throw new BadResponseException("Invalid response: ${httpReq.method()} ${httpReq.uri()} [${resp.statusCode()}] ${resp.body()}")
167+
if( resp.code == 200 ) {
168+
final ret = parseLicenseTokenResponse(resp.message)
169+
return ret
317170
}
318-
catch (IOException e) {
319-
throw new IllegalStateException("Unable to send request to '${httpReq.uri()}' : ${e.message}")
171+
172+
if( resp.code == 401 ) {
173+
throw new UnauthorizedException("Unauthorized [401] - Verify you have provided a Seqera Platform valid access token")
320174
}
175+
176+
throw new BadResponseException("Invalid response: ${url} [${resp.code}] ${resp.message} -- ${resp.cause}")
321177
}
322178
}

plugins/nf-tower/src/main/io/seqera/tower/plugin/exchange/GetLicenseTokenRequest.groovy

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,16 @@ class GetLicenseTokenRequest {
2929
* The Platform workspace ID associated with this request
3030
*/
3131
String workspaceId
32+
33+
/**
34+
* @return a Map representation of the request
35+
*/
36+
Map<String, String> toMap() {
37+
final map = new HashMap<String, String>()
38+
map.product = this.product
39+
map.version = this.version
40+
map.workflowId = this.workflowId
41+
map.workspaceId = this.workspaceId
42+
return map
43+
}
3244
}

0 commit comments

Comments
 (0)