20
20
import java .sql .Connection ;
21
21
import java .sql .SQLException ;
22
22
import java .util .Arrays ;
23
+ import java .util .HashSet ;
23
24
import java .util .Map ;
24
25
import java .util .Objects ;
25
- import java .util .Optional ;
26
26
import java .util .Queue ;
27
27
import java .util .Set ;
28
28
import java .util .concurrent .ConcurrentHashMap ;
29
29
import java .util .concurrent .ConcurrentLinkedQueue ;
30
+ import java .util .concurrent .Executor ;
30
31
import java .util .concurrent .ExecutorService ;
31
32
import java .util .concurrent .Executors ;
32
33
import java .util .logging .Logger ;
35
36
import software .amazon .jdbc .util .Messages ;
36
37
import software .amazon .jdbc .util .RdsUtils ;
37
38
import software .amazon .jdbc .util .StringUtils ;
39
+ import software .amazon .jdbc .util .SynchronousExecutor ;
38
40
import software .amazon .jdbc .util .telemetry .TelemetryContext ;
39
41
import software .amazon .jdbc .util .telemetry .TelemetryFactory ;
40
42
import software .amazon .jdbc .util .telemetry .TelemetryTraceLevel ;
@@ -50,16 +52,17 @@ public class OpenedConnectionTracker {
50
52
invalidateThread .setDaemon (true );
51
53
return invalidateThread ;
52
54
});
53
- private static final ExecutorService abortConnectionExecutorService =
54
- Executors .newCachedThreadPool (
55
- r -> {
56
- final Thread abortThread = new Thread (r );
57
- abortThread .setDaemon (true );
58
- return abortThread ;
59
- });
55
+ private static final Executor abortConnectionExecutor = new SynchronousExecutor ();
60
56
61
57
private static final Logger LOGGER = Logger .getLogger (OpenedConnectionTracker .class .getName ());
62
58
private static final RdsUtils rdsUtils = new RdsUtils ();
59
+
60
+ private static final Set <String > safeToCheckClosedClasses = new HashSet <>(Arrays .asList (
61
+ "HikariProxyConnection" ,
62
+ "org.postgresql.jdbc.PgConnection" ,
63
+ "com.mysql.cj.jdbc.ConnectionImpl" ,
64
+ "org.mariadb.jdbc.Connection" ));
65
+
63
66
private final PluginService pluginService ;
64
67
65
68
public OpenedConnectionTracker (final PluginService pluginService ) {
@@ -72,6 +75,7 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
72
75
// Check if the connection was established using an instance endpoint
73
76
if (rdsUtils .isRdsInstance (hostSpec .getHost ())) {
74
77
trackConnection (hostSpec .getHostAndPort (), conn );
78
+ logOpenedConnections ();
75
79
return ;
76
80
}
77
81
@@ -80,14 +84,17 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
80
84
.max (String ::compareToIgnoreCase )
81
85
.orElse (null );
82
86
83
- if (instanceEndpoint == null ) {
84
- LOGGER .finest (
85
- Messages .get ("OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue" ,
86
- new Object [] {hostSpec .getHost ()}));
87
+ if (instanceEndpoint != null ) {
88
+ trackConnection (instanceEndpoint , conn );
89
+ logOpenedConnections ();
87
90
return ;
88
91
}
89
92
90
- trackConnection (instanceEndpoint , conn );
93
+ // It seems there's no RDS instance host found. It might be a custom domain name. Let's track by all aliases
94
+ for (String alias : aliases ) {
95
+ trackConnection (alias , conn );
96
+ }
97
+ logOpenedConnections ();
91
98
}
92
99
93
100
/**
@@ -100,28 +107,27 @@ public void invalidateAllConnections(final HostSpec hostSpec) {
100
107
invalidateAllConnections (hostSpec .getAliases ().toArray (new String [] {}));
101
108
}
102
109
103
- public void invalidateAllConnections (final String ... node ) {
110
+ public void invalidateAllConnections (final String ... keys ) {
104
111
TelemetryFactory telemetryFactory = this .pluginService .getTelemetryFactory ();
105
112
TelemetryContext telemetryContext = telemetryFactory .openTelemetryContext (
106
113
TELEMETRY_INVALIDATE_CONNECTIONS , TelemetryTraceLevel .NESTED );
107
114
108
115
try {
109
- final Optional <String > instanceEndpoint = Arrays .stream (node )
110
- .filter (x -> rdsUtils .isRdsInstance (rdsUtils .removePort (x )))
111
- .findFirst ();
112
- if (!instanceEndpoint .isPresent ()) {
113
- return ;
116
+ for (String key : keys ) {
117
+ try {
118
+ final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (key );
119
+ logConnectionQueue (key , connectionQueue );
120
+ invalidateConnections (connectionQueue );
121
+ } catch (Exception ex ) {
122
+ // ignore and continue
123
+ }
114
124
}
115
- final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (instanceEndpoint .get ());
116
- logConnectionQueue (instanceEndpoint .get (), connectionQueue );
117
- invalidateConnections (openedConnections .get (instanceEndpoint .get ()));
118
-
119
125
} finally {
120
126
telemetryContext .closeContext ();
121
127
}
122
128
}
123
129
124
- public void invalidateCurrentConnection (final HostSpec hostSpec , final Connection connection ) {
130
+ public void removeConnectionTracking (final HostSpec hostSpec , final Connection connection ) {
125
131
final String host = rdsUtils .isRdsInstance (hostSpec .getHost ())
126
132
? hostSpec .asAlias ()
127
133
: hostSpec .getAliases ().stream ()
@@ -134,8 +140,11 @@ public void invalidateCurrentConnection(final HostSpec hostSpec, final Connectio
134
140
}
135
141
136
142
final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (host );
137
- logConnectionQueue (host , connectionQueue );
138
- connectionQueue .removeIf (connectionWeakReference -> Objects .equals (connectionWeakReference .get (), connection ));
143
+ if (connectionQueue != null ) {
144
+ logConnectionQueue (host , connectionQueue );
145
+ connectionQueue .removeIf (connectionWeakReference -> connectionWeakReference != null
146
+ && Objects .equals (connectionWeakReference .get (), connection ));
147
+ }
139
148
}
140
149
141
150
private void trackConnection (final String instanceEndpoint , final Connection connection ) {
@@ -144,10 +153,12 @@ private void trackConnection(final String instanceEndpoint, final Connection con
144
153
instanceEndpoint ,
145
154
(k ) -> new ConcurrentLinkedQueue <>());
146
155
connectionQueue .add (new WeakReference <>(connection ));
147
- logOpenedConnections ();
148
156
}
149
157
150
158
private void invalidateConnections (final Queue <WeakReference <Connection >> connectionQueue ) {
159
+ if (connectionQueue == null || connectionQueue .isEmpty ()) {
160
+ return ;
161
+ }
151
162
invalidateConnectionsExecutorService .submit (() -> {
152
163
WeakReference <Connection > connReference ;
153
164
while ((connReference = connectionQueue .poll ()) != null ) {
@@ -157,7 +168,7 @@ private void invalidateConnections(final Queue<WeakReference<Connection>> connec
157
168
}
158
169
159
170
try {
160
- conn .abort (abortConnectionExecutorService );
171
+ conn .abort (abortConnectionExecutor );
161
172
} catch (final SQLException e ) {
162
173
// swallow this exception, current connection should be useless anyway.
163
174
}
@@ -204,7 +215,10 @@ public void pruneNullConnections() {
204
215
if (conn == null ) {
205
216
return true ;
206
217
}
207
- if (conn .getClass ().getSimpleName ().equals ("HikariProxyConnection" )) {
218
+ // The following classes do not check connection validity by calling a DB server
219
+ // so it's safe to check whether connection is already closed.
220
+ if (safeToCheckClosedClasses .contains (conn .getClass ().getSimpleName ())
221
+ || safeToCheckClosedClasses .contains (conn .getClass ().getName ())) {
208
222
try {
209
223
return conn .isClosed ();
210
224
} catch (SQLException ex ) {
0 commit comments