1111import java .util .concurrent .atomic .AtomicReference ;
1212import java .util .function .Function ;
1313
14- import org .slf4j .Logger ;
15- import org .slf4j .LoggerFactory ;
16-
1714import io .modelcontextprotocol .spec .McpClientSession ;
1815import io .modelcontextprotocol .spec .McpError ;
1916import io .modelcontextprotocol .spec .McpSchema ;
2017import io .modelcontextprotocol .spec .McpTransportSessionNotFoundException ;
2118import io .modelcontextprotocol .util .Assert ;
19+ import org .slf4j .Logger ;
20+ import org .slf4j .LoggerFactory ;
2221import reactor .core .publisher .Mono ;
2322import reactor .core .publisher .Sinks ;
2423import reactor .util .context .ContextView ;
@@ -99,21 +98,30 @@ class LifecycleInitializer {
9998 */
10099 private final Duration initializationTimeout ;
101100
101+ /**
102+ * Post-initialization hook to perform additional operations after every successful
103+ * initialization.
104+ */
105+ private final Function <Initialization , Mono <Void >> postInitializationHook ;
106+
102107 public LifecycleInitializer (McpSchema .ClientCapabilities clientCapabilities , McpSchema .Implementation clientInfo ,
103108 List <String > protocolVersions , Duration initializationTimeout ,
104- Function <ContextView , McpClientSession > sessionSupplier ) {
109+ Function <ContextView , McpClientSession > sessionSupplier ,
110+ Function <Initialization , Mono <Void >> postInitializationHook ) {
105111
106112 Assert .notNull (sessionSupplier , "Session supplier must not be null" );
107113 Assert .notNull (clientCapabilities , "Client capabilities must not be null" );
108114 Assert .notNull (clientInfo , "Client info must not be null" );
109115 Assert .notEmpty (protocolVersions , "Protocol versions must not be empty" );
110116 Assert .notNull (initializationTimeout , "Initialization timeout must not be null" );
117+ Assert .notNull (postInitializationHook , "Post-initialization hook must not be null" );
111118
112119 this .sessionSupplier = sessionSupplier ;
113120 this .clientCapabilities = clientCapabilities ;
114121 this .clientInfo = clientInfo ;
115122 this .protocolVersions = Collections .unmodifiableList (new ArrayList <>(protocolVersions ));
116123 this .initializationTimeout = initializationTimeout ;
124+ this .postInitializationHook = postInitializationHook ;
117125 }
118126
119127 /**
@@ -148,10 +156,6 @@ interface Initialization {
148156
149157 }
150158
151- /**
152- * Default implementation of the {@link Initialization} interface that manages the MCP
153- * client initialization process.
154- */
155159 private static class DefaultInitialization implements Initialization {
156160
157161 /**
@@ -199,29 +203,20 @@ private void setMcpClientSession(McpClientSession mcpClientSession) {
199203 this .mcpClientSession .set (mcpClientSession );
200204 }
201205
202- /**
203- * Returns a Mono that completes when the MCP client initialization is complete.
204- * This allows subscribers to wait for the initialization to finish before
205- * proceeding with further operations.
206- * @return A Mono that emits the result of the MCP initialization process
207- */
208206 private Mono <McpSchema .InitializeResult > await () {
209207 return this .initSink .asMono ();
210208 }
211209
212- /**
213- * Completes the initialization process with the given result. It caches the
214- * result and emits it to all subscribers waiting for the initialization to
215- * complete.
216- * @param initializeResult The result of the MCP initialization process
217- */
218210 private void complete (McpSchema .InitializeResult initializeResult ) {
219- // first ensure the result is cached
220- this .result .set (initializeResult );
221211 // inform all the subscribers waiting for the initialization
222212 this .initSink .emitValue (initializeResult , Sinks .EmitFailureHandler .FAIL_FAST );
223213 }
224214
215+ private void cacheResult (McpSchema .InitializeResult initializeResult ) {
216+ // first ensure the result is cached
217+ this .result .set (initializeResult );
218+ }
219+
225220 private void error (Throwable t ) {
226221 this .initSink .emitError (t , Sinks .EmitFailureHandler .FAIL_FAST );
227222 }
@@ -263,7 +258,7 @@ public void handleException(Throwable t) {
263258 }
264259 // Providing an empty operation since we are only interested in triggering
265260 // the implicit initialization step.
266- withIntitialization ("re-initializing" , result -> Mono .empty ()).subscribe ();
261+ this . withInitialization ("re-initializing" , result -> Mono .empty ()).subscribe ();
267262 }
268263 }
269264
@@ -275,16 +270,16 @@ public void handleException(Throwable t) {
275270 * @param operation The operation to execute when the client is initialized
276271 * @return A Mono that completes with the result of the operation
277272 */
278- public <T > Mono <T > withIntitialization (String actionName , Function <Initialization , Mono <T >> operation ) {
273+ public <T > Mono <T > withInitialization (String actionName , Function <Initialization , Mono <T >> operation ) {
279274 return Mono .deferContextual (ctx -> {
280275 DefaultInitialization newInit = new DefaultInitialization ();
281276 DefaultInitialization previous = this .initializationRef .compareAndExchange (null , newInit );
282277
283278 boolean needsToInitialize = previous == null ;
284279 logger .debug (needsToInitialize ? "Initialization process started" : "Joining previous initialization" );
285280
286- Mono <McpSchema .InitializeResult > initializationJob = needsToInitialize ? doInitialize ( newInit , ctx )
287- : previous .await ();
281+ Mono <McpSchema .InitializeResult > initializationJob = needsToInitialize
282+ ? this . doInitialize ( newInit , this . postInitializationHook , ctx ) : previous .await ();
288283
289284 return initializationJob .map (initializeResult -> this .initializationRef .get ())
290285 .timeout (this .initializationTimeout )
@@ -296,7 +291,9 @@ public <T> Mono<T> withIntitialization(String actionName, Function<Initializatio
296291 });
297292 }
298293
299- private Mono <McpSchema .InitializeResult > doInitialize (DefaultInitialization initialization , ContextView ctx ) {
294+ private Mono <McpSchema .InitializeResult > doInitialize (DefaultInitialization initialization ,
295+ Function <Initialization , Mono <Void >> postInitOperation , ContextView ctx ) {
296+
300297 initialization .setMcpClientSession (this .sessionSupplier .apply (ctx ));
301298
302299 McpClientSession mcpClientSession = initialization .mcpSession ();
@@ -323,6 +320,9 @@ private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization init
323320
324321 return mcpClientSession .sendNotification (McpSchema .METHOD_NOTIFICATION_INITIALIZED , null )
325322 .thenReturn (initializeResult );
323+ }).flatMap (initializeResult -> {
324+ initialization .cacheResult (initializeResult );
325+ return postInitOperation .apply (initialization ).thenReturn (initializeResult );
326326 }).doOnNext (initialization ::complete ).onErrorResume (ex -> {
327327 initialization .error (ex );
328328 return Mono .error (ex );
0 commit comments