2222import io .netty .channel .Channel ;
2323import io .netty .channel .EventLoopGroup ;
2424
25+ import java .util .HashMap ;
26+ import java .util .Iterator ;
27+ import java .util .Map ;
2528import java .util .Set ;
2629import java .util .concurrent .CompletableFuture ;
2730import java .util .concurrent .CompletionException ;
2831import java .util .concurrent .CompletionStage ;
29- import java .util .concurrent .ConcurrentHashMap ;
30- import java .util .concurrent .ConcurrentMap ;
3132import java .util .concurrent .TimeUnit ;
3233import java .util .concurrent .TimeoutException ;
3334import java .util .concurrent .atomic .AtomicBoolean ;
35+ import java .util .concurrent .locks .Lock ;
36+ import java .util .concurrent .locks .ReadWriteLock ;
37+ import java .util .concurrent .locks .ReentrantReadWriteLock ;
38+ import java .util .function .Supplier ;
3439
3540import org .neo4j .driver .Logger ;
3641import org .neo4j .driver .Logging ;
@@ -61,7 +66,8 @@ public class ConnectionPoolImpl implements ConnectionPool
6166 private final MetricsListener metricsListener ;
6267 private final boolean ownsEventLoopGroup ;
6368
64- private final ConcurrentMap <BoltServerAddress ,ExtendedChannelPool > pools = new ConcurrentHashMap <>();
69+ private final ReadWriteLock addressToPoolLock = new ReentrantReadWriteLock ();
70+ private final Map <BoltServerAddress ,ExtendedChannelPool > addressToPool = new HashMap <>();
6571 private final AtomicBoolean closed = new AtomicBoolean ();
6672 private final CompletableFuture <Void > closeFuture = new CompletableFuture <>();
6773 private final ConnectionFactory connectionFactory ;
@@ -124,25 +130,32 @@ public CompletionStage<Connection> acquire( BoltServerAddress address )
124130 @ Override
125131 public void retainAll ( Set <BoltServerAddress > addressesToRetain )
126132 {
127- for ( BoltServerAddress address : pools . keySet () )
133+ executeWithLock ( addressToPoolLock . writeLock (), () ->
128134 {
129- if ( !addressesToRetain .contains ( address ) )
135+ Iterator <Map .Entry <BoltServerAddress ,ExtendedChannelPool >> entryIterator = addressToPool .entrySet ().iterator ();
136+ while ( entryIterator .hasNext () )
130137 {
131- int activeChannels = nettyChannelTracker .inUseChannelCount ( address );
132- if ( activeChannels == 0 )
138+ Map .Entry <BoltServerAddress ,ExtendedChannelPool > entry = entryIterator .next ();
139+ BoltServerAddress address = entry .getKey ();
140+ if ( !addressesToRetain .contains ( address ) )
133141 {
134- // address is not present in updated routing table and has no active connections
135- // it's now safe to terminate corresponding connection pool and forget about it
136- ExtendedChannelPool pool = pools .remove ( address );
137- if ( pool != null )
142+ int activeChannels = nettyChannelTracker .inUseChannelCount ( address );
143+ if ( activeChannels == 0 )
138144 {
139- log .info ( "Closing connection pool towards %s, it has no active connections " +
140- "and is not in the routing table registry." , address );
141- closePoolInBackground ( address , pool );
145+ // address is not present in updated routing table and has no active connections
146+ // it's now safe to terminate corresponding connection pool and forget about it
147+ ExtendedChannelPool pool = entry .getValue ();
148+ entryIterator .remove ();
149+ if ( pool != null )
150+ {
151+ log .info ( "Closing connection pool towards %s, it has no active connections " +
152+ "and is not in the routing table registry." , address );
153+ closePoolInBackground ( address , pool );
154+ }
142155 }
143156 }
144157 }
145- }
158+ } );
146159 }
147160
148161 @ Override
@@ -163,35 +176,40 @@ public CompletionStage<Void> close()
163176 if ( closed .compareAndSet ( false , true ) )
164177 {
165178 nettyChannelTracker .prepareToCloseChannels ();
166- CompletableFuture <Void > allPoolClosedFuture = closeAllPools ();
167179
168- // We can only shutdown event loop group when all netty pools are fully closed,
169- // otherwise the netty pools might missing threads (from event loop group) to execute clean ups.
170- allPoolClosedFuture .whenComplete ( ( ignored , pollCloseError ) -> {
171- pools .clear ();
172- if ( !ownsEventLoopGroup )
173- {
174- completeWithNullIfNoError ( closeFuture , pollCloseError );
175- }
176- else
177- {
178- shutdownEventLoopGroup ( pollCloseError );
179- }
180- } );
180+ executeWithLockAsync ( addressToPoolLock .writeLock (),
181+ () ->
182+ {
183+ // We can only shutdown event loop group when all netty pools are fully closed,
184+ // otherwise the netty pools might missing threads (from event loop group) to execute clean ups.
185+ return closeAllPools ().whenComplete (
186+ ( ignored , pollCloseError ) ->
187+ {
188+ addressToPool .clear ();
189+ if ( !ownsEventLoopGroup )
190+ {
191+ completeWithNullIfNoError ( closeFuture , pollCloseError );
192+ }
193+ else
194+ {
195+ shutdownEventLoopGroup ( pollCloseError );
196+ }
197+ } );
198+ } );
181199 }
182200 return closeFuture ;
183201 }
184202
185203 @ Override
186204 public boolean isOpen ( BoltServerAddress address )
187205 {
188- return pools . containsKey ( address );
206+ return executeWithLock ( addressToPoolLock . readLock (), () -> addressToPool . containsKey ( address ) );
189207 }
190208
191209 @ Override
192210 public String toString ()
193211 {
194- return "ConnectionPoolImpl{" + "pools=" + pools + '}' ;
212+ return executeWithLock ( addressToPoolLock . readLock (), () -> "ConnectionPoolImpl{" + "pools=" + addressToPool + '}' ) ;
195213 }
196214
197215 private void processAcquisitionError ( ExtendedChannelPool pool , BoltServerAddress serverAddress , Throwable error )
@@ -237,15 +255,15 @@ private void assertNotClosed( BoltServerAddress address, Channel channel, Extend
237255 {
238256 pool .release ( channel );
239257 closePoolInBackground ( address , pool );
240- pools . remove ( address );
258+ executeWithLock ( addressToPoolLock . writeLock (), () -> addressToPool . remove ( address ) );
241259 assertNotClosed ();
242260 }
243261 }
244262
245263 // for testing only
246264 ExtendedChannelPool getPool ( BoltServerAddress address )
247265 {
248- return pools . get ( address );
266+ return executeWithLock ( addressToPoolLock . readLock (), () -> addressToPool . get ( address ) );
249267 }
250268
251269 ExtendedChannelPool newPool ( BoltServerAddress address )
@@ -256,12 +274,22 @@ ExtendedChannelPool newPool( BoltServerAddress address )
256274
257275 private ExtendedChannelPool getOrCreatePool ( BoltServerAddress address )
258276 {
259- return pools .computeIfAbsent ( address , ignored -> {
260- ExtendedChannelPool pool = newPool ( address );
261- // before the connection pool is added I can add the metrics for the pool.
262- metricsListener .putPoolMetrics ( pool .id (), address , this );
263- return pool ;
264- } );
277+ ExtendedChannelPool existingPool = executeWithLock ( addressToPoolLock .readLock (), () -> addressToPool .get ( address ) );
278+ return existingPool != null
279+ ? existingPool
280+ : executeWithLock ( addressToPoolLock .writeLock (),
281+ () ->
282+ {
283+ ExtendedChannelPool pool = addressToPool .get ( address );
284+ if ( pool == null )
285+ {
286+ pool = newPool ( address );
287+ // before the connection pool is added I can add the metrics for the pool.
288+ metricsListener .putPoolMetrics ( pool .id (), address , this );
289+ addressToPool .put ( address , pool );
290+ }
291+ return pool ;
292+ } );
265293 }
266294
267295 private CompletionStage <Void > closePool ( ExtendedChannelPool pool )
@@ -303,12 +331,45 @@ private void shutdownEventLoopGroup( Throwable pollCloseError )
303331 private CompletableFuture <Void > closeAllPools ()
304332 {
305333 return CompletableFuture .allOf (
306- pools .entrySet ().stream ().map ( entry -> {
307- BoltServerAddress address = entry .getKey ();
308- ExtendedChannelPool pool = entry .getValue ();
309- log .info ( "Closing connection pool towards %s" , address );
310- // Wait for all pools to be closed.
311- return closePool ( pool ).toCompletableFuture ();
312- } ).toArray ( CompletableFuture []::new ) );
334+ addressToPool .entrySet ().stream ()
335+ .map ( entry ->
336+ {
337+ BoltServerAddress address = entry .getKey ();
338+ ExtendedChannelPool pool = entry .getValue ();
339+ log .info ( "Closing connection pool towards %s" , address );
340+ // Wait for all pools to be closed.
341+ return closePool ( pool ).toCompletableFuture ();
342+ } )
343+ .toArray ( CompletableFuture []::new ) );
344+ }
345+
346+ private void executeWithLock ( Lock lock , Runnable runnable )
347+ {
348+ executeWithLock ( lock , () ->
349+ {
350+ runnable .run ();
351+ return null ;
352+ } );
353+ }
354+
355+ private <T > T executeWithLock ( Lock lock , Supplier <T > supplier )
356+ {
357+ lock .lock ();
358+ try
359+ {
360+ return supplier .get ();
361+ }
362+ finally
363+ {
364+ lock .unlock ();
365+ }
366+ }
367+
368+ private <T > void executeWithLockAsync ( Lock lock , Supplier <CompletionStage <T >> stageSupplier )
369+ {
370+ lock .lock ();
371+ CompletableFuture .completedFuture ( lock )
372+ .thenCompose ( ignored -> stageSupplier .get () )
373+ .whenComplete ( ( ignored , throwable ) -> lock .unlock () );
313374 }
314375}
0 commit comments