@@ -283,15 +283,57 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
283283 return cn , nil
284284}
285285
286- func (c * baseClient ) reAuth (ctx context.Context , cn * Conn , credentials auth.Credentials ) error {
287- var err error
288- username , password := credentials .BasicAuth ()
289- if username != "" {
290- err = cn .AuthACL (ctx , username , password ).Err ()
291- } else {
292- err = cn .Auth (ctx , password ).Err ()
286+ func (c * baseClient ) newReAuthCredentialsListener (ctx context.Context , conn * Conn ) auth.CredentialsListener {
287+ return auth .NewReAuthCredentialsListener (
288+ c .reAuthConnection (c .context (ctx ), conn ),
289+ c .onAuthenticationErr (c .context (ctx ), conn ),
290+ )
291+ }
292+
293+ func (c * baseClient ) reAuthConnection (ctx context.Context , cn * Conn ) func (credentials auth.Credentials ) error {
294+ return func (credentials auth.Credentials ) error {
295+ var err error
296+ username , password := credentials .BasicAuth ()
297+ if username != "" {
298+ err = cn .AuthACL (ctx , username , password ).Err ()
299+ } else {
300+ err = cn .Auth (ctx , password ).Err ()
301+ }
302+ return err
303+ }
304+ }
305+ func (c * baseClient ) onAuthenticationErr (ctx context.Context , cn * Conn ) func (err error ) {
306+ return func (err error ) {
307+ // since the connection pool of the *Conn will actually return us the underlying pool.Conn,
308+ // we can get it from the *Conn and remove it from the clients pool.
309+ if err != nil {
310+ if isBadConn (err , false , c .opt .Addr ) {
311+ poolCn , _ := cn .connPool .Get (ctx )
312+ c .connPool .Remove (ctx , poolCn , err )
313+ }
314+ }
315+ }
316+ }
317+
318+ func (c * baseClient ) wrappedOnClose (newOnClose func () error ) func () error {
319+ onClose := c .onClose
320+ return func () error {
321+ var firstErr error
322+ err := newOnClose ()
323+ // Even if we have an error we would like to execute the onClose hook
324+ // if it exists. We will return the first error that occurred.
325+ // This is to keep error handling consistent with the rest of the code.
326+ if err != nil {
327+ firstErr = err
328+ }
329+ if onClose != nil {
330+ err = onClose ()
331+ if err != nil && firstErr == nil {
332+ firstErr = err
333+ }
334+ }
335+ return firstErr
293336 }
294- return err
295337}
296338
297339func (c * baseClient ) initConn (ctx context.Context , cn * pool.Conn ) error {
@@ -312,7 +354,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
312354
313355 var authenticated bool
314356 username , password := c .opt .Username , c .opt .Password
315- if c .opt .CredentialsProviderContext != nil {
357+ if c .opt .StreamingCredentialsProvider != nil {
358+ credentials , cancelCredentialsProvider , err := c .opt .StreamingCredentialsProvider .
359+ Subscribe (c .newReAuthCredentialsListener (ctx , conn ))
360+ if err != nil {
361+ return err
362+ }
363+ c .onClose = c .wrappedOnClose (cancelCredentialsProvider )
364+ username , password = credentials .BasicAuth ()
365+ } else if c .opt .CredentialsProviderContext != nil {
316366 if username , password , err = c .opt .CredentialsProviderContext (ctx ); err != nil {
317367 return err
318368 }
@@ -336,7 +386,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
336386 }
337387
338388 if ! authenticated && password != "" {
339- err = c .reAuth (ctx , conn , auth .NewBasicCredentials (username , password ))
389+ err = c .reAuthConnection (ctx , conn )( auth .NewBasicCredentials (username , password ))
340390 if err != nil {
341391 return err
342392 }
0 commit comments