Skip to content

Commit 3f57866

Browse files
committed
adds context to the V4 signer's utilities
1 parent d97f6a6 commit 3f57866

File tree

9 files changed

+104
-90
lines changed

9 files changed

+104
-90
lines changed

aws/signer/v4/functional_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package v4_test
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"net/url"
@@ -162,7 +163,7 @@ func TestStandaloneSign_CustomURIEscape(t *testing.T) {
162163
req.URL.Path = `/log-*/_search`
163164
req.URL.Opaque = "//subdomain.us-east-1.es.amazonaws.com/log-%2A/_search"
164165

165-
_, err = signer.Sign(req, nil, "es", "us-east-1", time.Unix(0, 0))
166+
_, err = signer.Sign(context.Background(), req, nil, "es", "us-east-1", time.Unix(0, 0))
166167
if err != nil {
167168
t.Fatalf("expect no error, got %v", err)
168169
}
@@ -191,7 +192,7 @@ func TestStandaloneSign(t *testing.T) {
191192
req.URL.Path = c.OrigURI
192193
req.URL.RawQuery = c.OrigQuery
193194

194-
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
195+
_, err = signer.Sign(context.Background(), req, nil, c.Service, c.Region, time.Unix(0, 0))
195196
if err != nil {
196197
t.Errorf("expected no error, but received %v", err)
197198
}
@@ -228,7 +229,7 @@ func TestStandaloneSign_RawPath(t *testing.T) {
228229
req.URL.RawPath = c.EscapedURI
229230
req.URL.RawQuery = c.OrigQuery
230231

231-
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
232+
_, err = signer.Sign(context.Background(), req, nil, c.Service, c.Region, time.Unix(0, 0))
232233
if err != nil {
233234
t.Errorf("expected no error, but received %v", err)
234235
}

aws/signer/v4/v4.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ type signingCtx struct {
263263
// generated. To bypass the signer computing the hash you can set the
264264
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
265265
// only compute the hash if the request header value is empty.
266-
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
267-
return v4.signWithBody(r, body, service, region, 0, signTime)
266+
func (v4 Signer) Sign(ctx context.Context, r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
267+
return v4.signWithBody(ctx, r, body, service, region, 0, signTime)
268268
}
269269

270270
// Presign signs AWS v4 requests with the provided body, service name, region
@@ -297,12 +297,12 @@ func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region strin
297297
// PUT/GET capabilities. If you would like to include the body's SHA256 in the
298298
// presigned request's signature you can set the "X-Amz-Content-Sha256"
299299
// HTTP header and that will be included in the request's signature.
300-
func (v4 Signer) Presign(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
301-
return v4.signWithBody(r, body, service, region, exp, signTime)
300+
func (v4 Signer) Presign(ctx context.Context, r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
301+
return v4.signWithBody(ctx, r, body, service, region, exp, signTime)
302302
}
303303

304-
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
305-
ctx := &signingCtx{
304+
func (v4 Signer) signWithBody(ctx context.Context, r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
305+
signingCtx := &signingCtx{
306306
Request: r,
307307
Body: body,
308308
Query: r.URL.Query(),
@@ -315,31 +315,31 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
315315
unsignedPayload: v4.UnsignedPayload,
316316
}
317317

318-
for key := range ctx.Query {
319-
sort.Strings(ctx.Query[key])
318+
for key := range signingCtx.Query {
319+
sort.Strings(signingCtx.Query[key])
320320
}
321321

322-
if ctx.isRequestSigned() {
323-
ctx.Time = sdk.NowTime()
324-
ctx.handlePresignRemoval()
322+
if signingCtx.isRequestSigned() {
323+
signingCtx.Time = sdk.NowTime()
324+
signingCtx.handlePresignRemoval()
325325
}
326326

327327
var err error
328-
ctx.credValues, err = v4.Credentials.Retrieve(context.Background())
328+
signingCtx.credValues, err = v4.Credentials.Retrieve(ctx)
329329
if err != nil {
330330
return http.Header{}, err
331331
}
332332

333-
aws.SanitizeHostForHeader(ctx.Request)
334-
ctx.assignAmzQueryValues()
335-
if err := ctx.build(v4.DisableHeaderHoisting); err != nil {
333+
aws.SanitizeHostForHeader(signingCtx.Request)
334+
signingCtx.assignAmzQueryValues()
335+
if err := signingCtx.build(v4.DisableHeaderHoisting); err != nil {
336336
return nil, err
337337
}
338338

339339
// If the request is not presigned the body should be attached to it. This
340340
// prevents the confusion of wanting to send a signed request without
341341
// the body the request was signed for attached.
342-
if !(v4.DisableRequestBodyOverwrite || ctx.isPresign) {
342+
if !(v4.DisableRequestBodyOverwrite || signingCtx.isPresign) {
343343
var reader io.ReadCloser
344344
if body != nil {
345345
var ok bool
@@ -351,10 +351,10 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
351351
}
352352

353353
if v4.Debug.Matches(aws.LogDebugWithSigning) {
354-
v4.logSigningInfo(ctx)
354+
v4.logSigningInfo(signingCtx)
355355
}
356356

357-
return ctx.SignedHeaderVals, nil
357+
return signingCtx.SignedHeaderVals, nil
358358
}
359359

360360
func (ctx *signingCtx) handlePresignRemoval() {
@@ -455,7 +455,7 @@ func SignSDKRequest(req *aws.Request, opts ...func(*Signer)) {
455455
signingTime = req.LastSignedAt
456456
}
457457

458-
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
458+
signedHeaders, err := v4.signWithBody(req.Context(), req.HTTPRequest, req.GetBody(),
459459
name, region, req.ExpireTime, signingTime,
460460
)
461461
if err != nil {

aws/signer/v4/v4_test.go

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func TestPresignRequest(t *testing.T) {
125125
req, body := buildRequest("dynamodb", "us-east-1", "{}")
126126

127127
signer := buildSigner()
128-
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
128+
signer.Presign(context.Background(), req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
129129

130130
expectedDate := "19700101T000000Z"
131131
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
@@ -159,7 +159,7 @@ func TestPresignBodyWithArrayRequest(t *testing.T) {
159159
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
160160

161161
signer := buildSigner()
162-
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
162+
signer.Presign(context.Background(), req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
163163

164164
expectedDate := "19700101T000000Z"
165165
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
@@ -191,7 +191,7 @@ func TestPresignBodyWithArrayRequest(t *testing.T) {
191191
func TestSignRequest(t *testing.T) {
192192
req, body := buildRequest("dynamodb", "us-east-1", "{}")
193193
signer := buildSigner()
194-
signer.Sign(req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
194+
signer.Sign(context.Background(), req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
195195

196196
expectedDate := "19700101T000000Z"
197197
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
@@ -208,7 +208,7 @@ func TestSignRequest(t *testing.T) {
208208
func TestSignUnseekableBody(t *testing.T) {
209209
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
210210
signer := buildSigner()
211-
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
211+
_, err := signer.Sign(context.Background(), req, body, "mock-service", "mock-region", time.Now())
212212
if err == nil {
213213
t.Fatalf("expect error signing request")
214214
}
@@ -224,7 +224,7 @@ func TestSignUnsignedPayloadUnseekableBody(t *testing.T) {
224224
signer := buildSigner()
225225
signer.UnsignedPayload = true
226226

227-
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
227+
_, err := signer.Sign(context.Background(), req, body, "mock-service", "mock-region", time.Now())
228228
if err != nil {
229229
t.Fatalf("expect no error, got %v", err)
230230
}
@@ -241,7 +241,7 @@ func TestSignPreComputedHashUnseekableBody(t *testing.T) {
241241
signer := buildSigner()
242242

243243
req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
244-
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
244+
_, err := signer.Sign(context.Background(), req, body, "mock-service", "mock-region", time.Now())
245245
if err != nil {
246246
t.Fatalf("expect no error, got %v", err)
247247
}
@@ -255,7 +255,7 @@ func TestSignPreComputedHashUnseekableBody(t *testing.T) {
255255
func TestSignBodyS3(t *testing.T) {
256256
req, body := buildRequest("s3", "us-east-1", "hello")
257257
signer := buildSigner()
258-
signer.Sign(req, body, "s3", "us-east-1", time.Now())
258+
signer.Sign(context.Background(), req, body, "s3", "us-east-1", time.Now())
259259
hash := req.Header.Get("X-Amz-Content-Sha256")
260260
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
261261
t.Errorf("expect %v, got %v", e, a)
@@ -265,7 +265,7 @@ func TestSignBodyS3(t *testing.T) {
265265
func TestSignBodyGlacier(t *testing.T) {
266266
req, body := buildRequest("glacier", "us-east-1", "hello")
267267
signer := buildSigner()
268-
signer.Sign(req, body, "glacier", "us-east-1", time.Now())
268+
signer.Sign(context.Background(), req, body, "glacier", "us-east-1", time.Now())
269269
hash := req.Header.Get("X-Amz-Content-Sha256")
270270
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
271271
t.Errorf("expect %v, got %v", e, a)
@@ -275,7 +275,7 @@ func TestSignBodyGlacier(t *testing.T) {
275275
func TestPresign_SignedPayload(t *testing.T) {
276276
req, body := buildRequest("glacier", "us-east-1", "hello")
277277
signer := buildSigner()
278-
signer.Presign(req, body, "glacier", "us-east-1", 5*time.Minute, time.Now())
278+
signer.Presign(context.Background(), req, body, "glacier", "us-east-1", 5*time.Minute, time.Now())
279279
hash := req.Header.Get("X-Amz-Content-Sha256")
280280
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
281281
t.Errorf("expect %v, got %v", e, a)
@@ -286,7 +286,7 @@ func TestPresign_UnsignedPayload(t *testing.T) {
286286
req, body := buildRequest("service-name", "us-east-1", "hello")
287287
signer := buildSigner()
288288
signer.UnsignedPayload = true
289-
signer.Presign(req, body, "service-name", "us-east-1", 5*time.Minute, time.Now())
289+
signer.Presign(context.Background(), req, body, "service-name", "us-east-1", 5*time.Minute, time.Now())
290290
hash := req.Header.Get("X-Amz-Content-Sha256")
291291
if e, a := "UNSIGNED-PAYLOAD", hash; e != a {
292292
t.Errorf("expect %v, got %v", e, a)
@@ -296,7 +296,7 @@ func TestPresign_UnsignedPayload(t *testing.T) {
296296
func TestPresign_UnsignedPayload_S3(t *testing.T) {
297297
req, body := buildRequest("s3", "us-east-1", "hello")
298298
signer := buildSigner()
299-
signer.Presign(req, body, "s3", "us-east-1", 5*time.Minute, time.Now())
299+
signer.Presign(context.Background(), req, body, "s3", "us-east-1", 5*time.Minute, time.Now())
300300
if a := req.Header.Get("X-Amz-Content-Sha256"); len(a) != 0 {
301301
t.Errorf("expect no content sha256 got %v", a)
302302
}
@@ -306,7 +306,7 @@ func TestSignPrecomputedBodyChecksum(t *testing.T) {
306306
req, body := buildRequest("dynamodb", "us-east-1", "hello")
307307
req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
308308
signer := buildSigner()
309-
signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
309+
signer.Sign(context.Background(), req, body, "dynamodb", "us-east-1", time.Now())
310310
hash := req.Header.Get("X-Amz-Content-Sha256")
311311
if e, a := "PRECOMPUTED", hash; e != a {
312312
t.Errorf("expect %v, got %v", e, a)
@@ -620,7 +620,7 @@ func TestSignWithRequestBody(t *testing.T) {
620620

621621
req, err := http.NewRequest("POST", server.URL, nil)
622622

623-
_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
623+
_, err = signer.Sign(context.Background(), req, bytes.NewReader(expectBody), "service", "region", time.Now())
624624
if err != nil {
625625
t.Errorf("expect not no error, got %v", err)
626626
}
@@ -654,7 +654,7 @@ func TestSignWithRequestBody_Overwrite(t *testing.T) {
654654

655655
req, err := http.NewRequest("GET", server.URL, strings.NewReader("invalid body"))
656656

657-
_, err = signer.Sign(req, nil, "service", "region", time.Now())
657+
_, err = signer.Sign(context.Background(), req, nil, "service", "region", time.Now())
658658
req.ContentLength = 0
659659

660660
if err != nil {
@@ -698,7 +698,7 @@ func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
698698
s := NewSigner(creds)
699699
origBody := req.Body
700700

701-
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
701+
_, err := s.Sign(context.Background(), req, seekerBody, "dynamodb", "us-east-1", time.Now())
702702
if err != nil {
703703
t.Fatalf("expect no error, got %v", err)
704704
}
@@ -723,7 +723,7 @@ func TestSignWithBody_NoReplaceRequestBody(t *testing.T) {
723723

724724
origBody := req.Body
725725

726-
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
726+
_, err := s.Sign(context.Background(), req, seekerBody, "dynamodb", "us-east-1", time.Now())
727727
if err != nil {
728728
t.Fatalf("expect no error, got %v", err)
729729
}
@@ -757,15 +757,15 @@ func BenchmarkPresignRequest(b *testing.B) {
757757
signer := buildSigner()
758758
req, body := buildRequest("dynamodb", "us-east-1", "{}")
759759
for i := 0; i < b.N; i++ {
760-
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Now())
760+
signer.Presign(context.Background(), req, body, "dynamodb", "us-east-1", 300*time.Second, time.Now())
761761
}
762762
}
763763

764764
func BenchmarkSignRequest(b *testing.B) {
765765
signer := buildSigner()
766766
req, body := buildRequest("dynamodb", "us-east-1", "{}")
767767
for i := 0; i < b.N; i++ {
768-
signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
768+
signer.Sign(context.Background(), req, body, "dynamodb", "us-east-1", time.Now())
769769
}
770770
}
771771

example/service/rds/rdsutils/authentication/iam_authentication.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
package main
44

55
import (
6+
"context"
67
"database/sql"
78
"fmt"
89
"os"
910

10-
"github.com/go-sql-driver/mysql"
11-
1211
"github.com/aws/aws-sdk-go-v2/aws/external"
12+
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
1313
"github.com/aws/aws-sdk-go-v2/aws/stscreds"
1414
"github.com/aws/aws-sdk-go-v2/service/rds/rdsutils"
1515
"github.com/aws/aws-sdk-go-v2/service/sts"
16+
"github.com/go-sql-driver/mysql"
1617
)
1718

1819
// Usage ./iam_authentication <region> <db user> <db name> <endpoint to database> <iam arn>
@@ -35,8 +36,8 @@ func main() {
3536
cfg.Region = awsRegion
3637

3738
credProvider := stscreds.NewAssumeRoleProvider(sts.New(cfg), os.Args[5])
38-
39-
authToken, err := rdsutils.BuildAuthToken(dbEndpoint, awsRegion, dbUser, credProvider)
39+
signer := v4.NewSigner(credProvider)
40+
authToken, err := rdsutils.BuildAuthToken(context.Background(), dbEndpoint, awsRegion, dbUser, signer)
4041

4142
// Create the MySQL DNS string for the DB connection
4243
// user:password@protocol(endpoint)/dbname?<params>

service/rds/rdsutils/builder.go

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package rdsutils
22

33
import (
4+
"context"
45
"fmt"
56
"net/url"
67

7-
"github.com/aws/aws-sdk-go-v2/aws"
88
"github.com/aws/aws-sdk-go-v2/aws/awserr"
99
)
1010

@@ -26,24 +26,24 @@ var ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection form
2626
// string with the provided parameters. params field is required to have
2727
// a tls specification and allowCleartextPasswords must be set to true.
2828
type ConnectionStringBuilder struct {
29-
dbName string
30-
endpoint string
31-
region string
32-
user string
33-
credProvider aws.CredentialsProvider
29+
dbName string
30+
endpoint string
31+
region string
32+
user string
33+
signer HTTPV4Signer
3434

3535
connectFormat ConnectionFormat
3636
params url.Values
3737
}
3838

3939
// NewConnectionStringBuilder will return an ConnectionStringBuilder
40-
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, credProvider aws.CredentialsProvider) ConnectionStringBuilder {
40+
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, signer HTTPV4Signer) ConnectionStringBuilder {
4141
return ConnectionStringBuilder{
42-
dbName: dbName,
43-
endpoint: endpoint,
44-
region: region,
45-
user: dbUser,
46-
credProvider: credProvider,
42+
dbName: dbName,
43+
endpoint: endpoint,
44+
region: region,
45+
user: dbUser,
46+
signer: signer,
4747
}
4848
}
4949

@@ -99,19 +99,19 @@ func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder {
9999
// to the desired database.
100100
//
101101
// Example:
102-
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, credProvider)
103-
// connectStr, err := b.WithTCPFormat().Build()
102+
// signer := v4.NewSigner(credsProvider)
103+
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, signer)
104+
// connectStr, err := b.WithTCPFormat().Build(ctx)
104105
// if err != nil {
105106
// panic(err)
106107
// }
107108
// const dbType = "mysql"
108109
// db, err := sql.Open(dbType, connectStr)
109-
func (b ConnectionStringBuilder) Build() (string, error) {
110+
func (b ConnectionStringBuilder) Build(ctx context.Context) (string, error) {
110111
if b.connectFormat == NoConnectionFormat {
111112
return "", ErrNoConnectionFormat
112113
}
113-
114-
authToken, err := BuildAuthToken(b.endpoint, b.region, b.user, b.credProvider)
114+
authToken, err := BuildAuthToken(ctx, b.endpoint, b.region, b.user, b.signer)
115115
if err != nil {
116116
return "", err
117117
}

0 commit comments

Comments
 (0)