20
20
import java .sql .Connection ;
21
21
import java .sql .SQLException ;
22
22
import java .util .Arrays ;
23
+ import java .util .HashSet ;
23
24
import java .util .Map ;
24
25
import java .util .Objects ;
25
- import java .util .Optional ;
26
26
import java .util .Queue ;
27
27
import java .util .Set ;
28
28
import java .util .concurrent .ConcurrentHashMap ;
29
29
import java .util .concurrent .ConcurrentLinkedQueue ;
30
+ import java .util .concurrent .Executor ;
30
31
import java .util .concurrent .ExecutorService ;
31
32
import java .util .concurrent .Executors ;
32
33
import java .util .logging .Logger ;
35
36
import software .amazon .jdbc .util .Messages ;
36
37
import software .amazon .jdbc .util .RdsUtils ;
37
38
import software .amazon .jdbc .util .StringUtils ;
39
+ import software .amazon .jdbc .util .SynchronousExecutor ;
38
40
import software .amazon .jdbc .util .telemetry .TelemetryContext ;
39
41
import software .amazon .jdbc .util .telemetry .TelemetryFactory ;
40
42
import software .amazon .jdbc .util .telemetry .TelemetryTraceLevel ;
@@ -50,16 +52,17 @@ public class OpenedConnectionTracker {
50
52
invalidateThread .setDaemon (true );
51
53
return invalidateThread ;
52
54
});
53
- private static final ExecutorService abortConnectionExecutorService =
54
- Executors .newCachedThreadPool (
55
- r -> {
56
- final Thread abortThread = new Thread (r );
57
- abortThread .setDaemon (true );
58
- return abortThread ;
59
- });
55
+ private static final Executor abortConnectionExecutor = new SynchronousExecutor ();
60
56
61
57
private static final Logger LOGGER = Logger .getLogger (OpenedConnectionTracker .class .getName ());
62
58
private static final RdsUtils rdsUtils = new RdsUtils ();
59
+
60
+ private static final Set <String > safeCheckIfClosed = new HashSet <>(Arrays .asList (
61
+ "HikariProxyConnection" ,
62
+ "org.postgresql.jdbc.PgConnection" ,
63
+ "com.mysql.cj.jdbc.ConnectionImpl" ,
64
+ "org.mariadb.jdbc.Connection" ));
65
+
63
66
private final PluginService pluginService ;
64
67
65
68
public OpenedConnectionTracker (final PluginService pluginService ) {
@@ -72,6 +75,7 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
72
75
// Check if the connection was established using an instance endpoint
73
76
if (rdsUtils .isRdsInstance (hostSpec .getHost ())) {
74
77
trackConnection (hostSpec .getHostAndPort (), conn );
78
+ logOpenedConnections ();
75
79
return ;
76
80
}
77
81
@@ -80,14 +84,17 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
80
84
.max (String ::compareToIgnoreCase )
81
85
.orElse (null );
82
86
83
- if (instanceEndpoint == null ) {
84
- LOGGER .finest (
85
- Messages .get ("OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue" ,
86
- new Object [] {hostSpec .getHost ()}));
87
+ if (instanceEndpoint != null ) {
88
+ trackConnection (instanceEndpoint , conn );
89
+ logOpenedConnections ();
87
90
return ;
88
91
}
89
92
90
- trackConnection (instanceEndpoint , conn );
93
+ // It seems there's no RDS instance host found. It might be a custom domain name. Let's track by all aliases
94
+ for (String alias : aliases ) {
95
+ trackConnection (alias , conn );
96
+ }
97
+ logOpenedConnections ();
91
98
}
92
99
93
100
/**
@@ -100,22 +107,21 @@ public void invalidateAllConnections(final HostSpec hostSpec) {
100
107
invalidateAllConnections (hostSpec .getAliases ().toArray (new String [] {}));
101
108
}
102
109
103
- public void invalidateAllConnections (final String ... node ) {
110
+ public void invalidateAllConnections (final String ... keys ) {
104
111
TelemetryFactory telemetryFactory = this .pluginService .getTelemetryFactory ();
105
112
TelemetryContext telemetryContext = telemetryFactory .openTelemetryContext (
106
113
TELEMETRY_INVALIDATE_CONNECTIONS , TelemetryTraceLevel .NESTED );
107
114
108
115
try {
109
- final Optional <String > instanceEndpoint = Arrays .stream (node )
110
- .filter (x -> rdsUtils .isRdsInstance (rdsUtils .removePort (x )))
111
- .findFirst ();
112
- if (!instanceEndpoint .isPresent ()) {
113
- return ;
116
+ for (String key : keys ) {
117
+ try {
118
+ final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (key );
119
+ logConnectionQueue (key , connectionQueue );
120
+ invalidateConnections (connectionQueue );
121
+ } catch (Exception ex ) {
122
+ // ignore and continue
123
+ }
114
124
}
115
- final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (instanceEndpoint .get ());
116
- logConnectionQueue (instanceEndpoint .get (), connectionQueue );
117
- invalidateConnections (openedConnections .get (instanceEndpoint .get ()));
118
-
119
125
} finally {
120
126
telemetryContext .closeContext ();
121
127
}
@@ -134,8 +140,11 @@ public void invalidateCurrentConnection(final HostSpec hostSpec, final Connectio
134
140
}
135
141
136
142
final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (host );
137
- logConnectionQueue (host , connectionQueue );
138
- connectionQueue .removeIf (connectionWeakReference -> Objects .equals (connectionWeakReference .get (), connection ));
143
+ if (connectionQueue != null ) {
144
+ logConnectionQueue (host , connectionQueue );
145
+ connectionQueue .removeIf (connectionWeakReference -> connectionWeakReference != null
146
+ && Objects .equals (connectionWeakReference .get (), connection ));
147
+ }
139
148
}
140
149
141
150
private void trackConnection (final String instanceEndpoint , final Connection connection ) {
@@ -144,7 +153,6 @@ private void trackConnection(final String instanceEndpoint, final Connection con
144
153
instanceEndpoint ,
145
154
(k ) -> new ConcurrentLinkedQueue <>());
146
155
connectionQueue .add (new WeakReference <>(connection ));
147
- logOpenedConnections ();
148
156
}
149
157
150
158
private void invalidateConnections (final Queue <WeakReference <Connection >> connectionQueue ) {
@@ -157,7 +165,7 @@ private void invalidateConnections(final Queue<WeakReference<Connection>> connec
157
165
}
158
166
159
167
try {
160
- conn .abort (abortConnectionExecutorService );
168
+ conn .abort (abortConnectionExecutor );
161
169
} catch (final SQLException e ) {
162
170
// swallow this exception, current connection should be useless anyway.
163
171
}
@@ -204,7 +212,10 @@ public void pruneNullConnections() {
204
212
if (conn == null ) {
205
213
return true ;
206
214
}
207
- if (conn .getClass ().getSimpleName ().equals ("HikariProxyConnection" )) {
215
+ // The following classes do not check connection validity by calling a DB server
216
+ // so it's safe to check whether connection is already closed.
217
+ if (safeCheckIfClosed .contains (conn .getClass ().getSimpleName ())
218
+ || safeCheckIfClosed .contains (conn .getClass ().getName ())) {
208
219
try {
209
220
return conn .isClosed ();
210
221
} catch (SQLException ex ) {
0 commit comments