Skip to content

Commit 17b638e

Browse files
committed
additional improvements
1 parent aee2d1d commit 17b638e

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

aws/endpointcreds/provider.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func New(cfg aws.Config) *Provider {
8787
// Retrieve will attempt to request the credentials from the endpoint the Provider
8888
// was configured for. And error will be returned if the retrieval fails.
8989
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
90-
resp, err := p.getCredentials()
90+
resp, err := p.getCredentials(ctx)
9191
if err != nil {
9292
return aws.Credentials{},
9393
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
@@ -120,7 +120,7 @@ type errorOutput struct {
120120
Message string `json:"message"`
121121
}
122122

123-
func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
123+
func (p *Provider) getCredentials(ctx context.Context) (*getCredentialsOutput, error) {
124124
op := &aws.Operation{
125125
Name: "GetCredentials",
126126
HTTPMethod: "GET",
@@ -129,7 +129,7 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
129129
out := &getCredentialsOutput{}
130130
req := p.Client.NewRequest(op, nil, out)
131131
req.HTTPRequest.Header.Set("Accept", "application/json")
132-
132+
req.SetContext(ctx)
133133
return out, req.Send()
134134
}
135135

aws/signer/v2/v2_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package v2
22

33
import (
44
"bytes"
5+
"context"
56
"net/http"
67
"net/url"
78
"os"
@@ -75,7 +76,7 @@ func TestSignRequestWithAndWithoutSession(t *testing.T) {
7576

7677
signer := builder.BuildSigner()
7778

78-
err := signer.Sign(nil)
79+
err := signer.Sign(context.Background())
7980
if err != nil {
8081
t.Fatalf("expect no error, got %v", err)
8182
}
@@ -122,7 +123,7 @@ func TestSignRequestWithAndWithoutSession(t *testing.T) {
122123
builder.SessionToken = "SESSION"
123124
signer = builder.BuildSigner()
124125

125-
err = signer.Sign(nil)
126+
err = signer.Sign(context.Background())
126127
if err != nil {
127128
t.Fatalf("expect no error, got %v", err)
128129
}
@@ -165,7 +166,7 @@ func TestMoreComplexSignRequest(t *testing.T) {
165166

166167
signer := builder.BuildSigner()
167168

168-
err := signer.Sign(nil)
169+
err := signer.Sign(context.Background())
169170
if err != nil {
170171
t.Fatalf("expect no error, got %v", err)
171172
}

aws/static_provider.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package aws
22

33
import (
44
"context"
5+
56
"github.com/aws/aws-sdk-go-v2/aws/awserr"
67
)
78

@@ -33,6 +34,12 @@ func NewStaticCredentialsProvider(key, secret, session string) StaticCredentials
3334

3435
// Retrieve returns the credentials or error if the credentials are invalid.
3536
func (s StaticCredentialsProvider) Retrieve(ctx context.Context) (Credentials, error) {
37+
select {
38+
case <-ctx.Done():
39+
return Credentials{}, awserr.New(ErrCodeRequestCanceled, "context canceled", ctx.Err())
40+
default: // do nothing
41+
}
42+
3643
v := s.Value
3744
if v.AccessKeyID == "" || v.SecretAccessKey == "" {
3845
return Credentials{Source: StaticCredentialsProviderName}, ErrStaticCredentialsEmpty

aws/stscreds/provider.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,6 @@ func (p *AssumeRoleProvider) retrieveFn(ctx context.Context) (aws.Credentials, e
248248
}
249249

250250
req := p.Client.AssumeRoleRequest(input)
251-
if ctx == nil {
252-
ctx = context.Background()
253-
}
254251
resp, err := req.Send(ctx)
255252
if err != nil {
256253
return aws.Credentials{Source: ProviderName}, err

0 commit comments

Comments
 (0)