Skip to content

Commit 467aa8a

Browse files
skotambkarKotambkar
authored andcommitted
service/rds: Fix presign URL for same region (aws#331)
Fixes RDS no-autopresign URL for same region issue for aws-sdk-go-v2. Solves the issue by making sure that the presigned URLs are not created, when the source and destination regions are the same. Added and updated the tests accordingly. Fix aws#271
1 parent a357131 commit 467aa8a

File tree

2 files changed

+212
-66
lines changed

2 files changed

+212
-66
lines changed

service/rds/customizations.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ func copyDBSnapshotPresign(r *request.Request) {
4747
}
4848

4949
originParams.DestinationRegion = aws.String(r.Config.Region)
50+
// preSignedUrl is not required for instances in the same region.
51+
if *originParams.SourceRegion == *originParams.DestinationRegion {
52+
return
53+
}
5054
newParams := awsutil.CopyOf(r.Params).(*CopyDBSnapshotInput)
5155
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
5256
}
@@ -59,6 +63,10 @@ func createDBInstanceReadReplicaPresign(r *request.Request) {
5963
}
6064

6165
originParams.DestinationRegion = aws.String(r.Config.Region)
66+
// preSignedUrl is not required for instances in the same region.
67+
if *originParams.SourceRegion == *originParams.DestinationRegion {
68+
return
69+
}
6270
newParams := awsutil.CopyOf(r.Params).(*CreateDBInstanceReadReplicaInput)
6371
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
6472
}
@@ -71,6 +79,10 @@ func copyDBClusterSnapshotPresign(r *request.Request) {
7179
}
7280

7381
originParams.DestinationRegion = aws.String(r.Config.Region)
82+
// preSignedUrl is not required for instances in the same region.
83+
if *originParams.SourceRegion == *originParams.DestinationRegion {
84+
return
85+
}
7486
newParams := awsutil.CopyOf(r.Params).(*CopyDBClusterSnapshotInput)
7587
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
7688
}
@@ -83,6 +95,10 @@ func createDBClusterPresign(r *request.Request) {
8395
}
8496

8597
originParams.DestinationRegion = aws.String(r.Config.Region)
98+
// preSignedUrl is not required for instances in the same region.
99+
if *originParams.SourceRegion == *originParams.DestinationRegion {
100+
return
101+
}
86102
newParams := awsutil.CopyOf(r.Params).(*CreateDBClusterInput)
87103
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
88104
}

service/rds/customizations_test.go

Lines changed: 196 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"io/ioutil"
66
"net/url"
77
"regexp"
8-
"strings"
98
"testing"
109
"time"
1110

@@ -16,8 +15,7 @@ import (
1615
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
1716
)
1817

19-
func TestPresignWithPresignNotSet(t *testing.T) {
20-
reqs := map[string]*request.Request{}
18+
func TestCopyDBSnapshotNoPanic(t *testing.T) {
2119

2220
cfg := unit.Config()
2321
cfg.Region = "us-west-2"
@@ -34,73 +32,184 @@ func TestPresignWithPresignNotSet(t *testing.T) {
3432
t.Errorf("expect no panic, got %v", p)
3533
}
3634

37-
reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
38-
SourceRegion: aws.String("us-west-1"),
39-
SourceDBSnapshotIdentifier: aws.String("foo"),
40-
TargetDBSnapshotIdentifier: aws.String("bar"),
41-
}).Request
42-
43-
reqs[opCreateDBInstanceReadReplica] = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
44-
SourceRegion: aws.String("us-west-1"),
45-
SourceDBInstanceIdentifier: aws.String("foo"),
46-
DBInstanceIdentifier: aws.String("bar"),
47-
}).Request
48-
49-
for op, req := range reqs {
50-
req.Sign()
51-
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
52-
q, _ := url.ParseQuery(string(b))
53-
54-
u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
55-
56-
exp := fmt.Sprintf(`^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=us-west-2.+`, op)
57-
if re, a := regexp.MustCompile(exp), u; !re.MatchString(a) {
58-
t.Errorf("expect %s to match %s", re, a)
59-
}
60-
}
6135
}
6236

63-
func TestPresignWithPresignSet(t *testing.T) {
64-
reqs := map[string]*request.Request{}
37+
func TestPresignCrossRegionRequest(t *testing.T) {
6538

6639
cfg := unit.Config()
6740
cfg.Region = "us-west-2"
41+
cfg.EndpointResolver = endpoints.NewDefaultResolver()
6842

6943
svc := New(cfg)
44+
const regexPattern= `^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=%s.+`
7045

71-
f := func() {
72-
// Doesn't panic on nil input
73-
req := svc.CopyDBSnapshotRequest(nil)
74-
req.Sign()
75-
}
76-
if paniced, p := awstesting.DidPanic(f); paniced {
77-
t.Errorf("expect no panic, got %v", p)
78-
}
46+
cases := map[string]struct {
47+
Req *request.Request
48+
Assert func(*testing.T, string)
49+
}{
50+
opCopyDBSnapshot: {
51+
Req: func() *request.Request {
52+
req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
53+
SourceRegion: aws.String("us-west-1"),
54+
SourceDBSnapshotIdentifier: aws.String("foo"),
55+
TargetDBSnapshotIdentifier: aws.String("bar"),
56+
})
57+
return req.Request
58+
}(),
59+
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
60+
opCopyDBSnapshot, cfg.Region)),
61+
},
7962

80-
reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
81-
SourceRegion: aws.String("us-west-1"),
82-
SourceDBSnapshotIdentifier: aws.String("foo"),
83-
TargetDBSnapshotIdentifier: aws.String("bar"),
84-
PreSignedUrl: aws.String("presignedURL"),
85-
}).Request
63+
opCreateDBInstanceReadReplica: {
64+
Req: func() *request.Request {
65+
req := svc.CreateDBInstanceReadReplicaRequest(
66+
&CreateDBInstanceReadReplicaInput{
67+
SourceRegion: aws.String("us-west-1"),
68+
SourceDBInstanceIdentifier: aws.String("foo"),
69+
DBInstanceIdentifier: aws.String("bar"),
70+
})
71+
return req.Request
72+
}(),
73+
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
74+
opCreateDBInstanceReadReplica, cfg.Region)),
75+
},
76+
opCopyDBClusterSnapshot: {
77+
Req: func() *request.Request {
78+
req := svc.CopyDBClusterSnapshotRequest(
79+
&CopyDBClusterSnapshotInput{
80+
SourceRegion: aws.String("us-west-1"),
81+
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
82+
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
83+
})
84+
return req.Request
85+
}(),
86+
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
87+
opCopyDBClusterSnapshot, cfg.Region)),
88+
},
89+
opCreateDBCluster: {
90+
Req: func() *request.Request {
91+
req := svc.CreateDBClusterRequest(
92+
&CreateDBClusterInput{
93+
SourceRegion: aws.String("us-west-1"),
94+
DBClusterIdentifier: aws.String("foo"),
95+
Engine: aws.String("bar"),
96+
})
97+
return req.Request
98+
}(),
99+
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
100+
opCreateDBCluster, cfg.Region)),
101+
},
102+
opCopyDBSnapshot + " same region": {
103+
Req: func() *request.Request {
104+
req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
105+
SourceRegion: aws.String("us-west-2"),
106+
SourceDBSnapshotIdentifier: aws.String("foo"),
107+
TargetDBSnapshotIdentifier: aws.String("bar"),
108+
})
109+
return req.Request
110+
}(),
111+
Assert: assertAsEmpty(),
112+
},
113+
opCreateDBInstanceReadReplica + " same region": {
114+
Req: func() *request.Request {
115+
req := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
116+
SourceRegion: aws.String("us-west-2"),
117+
SourceDBInstanceIdentifier: aws.String("foo"),
118+
DBInstanceIdentifier: aws.String("bar"),
119+
})
120+
return req.Request
121+
}(),
122+
Assert: assertAsEmpty(),
123+
},
124+
opCopyDBClusterSnapshot + " same region": {
125+
Req: func() *request.Request {
126+
req := svc.CopyDBClusterSnapshotRequest(
127+
&CopyDBClusterSnapshotInput{
128+
SourceRegion: aws.String("us-west-2"),
129+
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
130+
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
131+
})
132+
return req.Request
133+
}(),
134+
Assert: assertAsEmpty(),
135+
},
136+
opCreateDBCluster + " same region": {
137+
Req: func() *request.Request {
138+
req := svc.CreateDBClusterRequest(
139+
&CreateDBClusterInput{
140+
SourceRegion: aws.String("us-west-2"),
141+
DBClusterIdentifier: aws.String("foo"),
142+
Engine: aws.String("bar"),
143+
})
144+
return req.Request
145+
}(),
146+
Assert: assertAsEmpty(),
147+
},
148+
opCopyDBSnapshot + " presignURL set": {
149+
Req: func() *request.Request {
150+
req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
151+
SourceRegion: aws.String("us-west-1"),
152+
SourceDBSnapshotIdentifier: aws.String("foo"),
153+
TargetDBSnapshotIdentifier: aws.String("bar"),
154+
PreSignedUrl: aws.String("mockPresignedURL"),
155+
})
156+
return req.Request
157+
}(),
158+
Assert: assertAsEqual("mockPresignedURL"),
159+
},
160+
opCreateDBInstanceReadReplica + " presignURL set": {
161+
Req: func() *request.Request {
162+
req := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
163+
SourceRegion: aws.String("us-west-1"),
164+
SourceDBInstanceIdentifier: aws.String("foo"),
165+
DBInstanceIdentifier: aws.String("bar"),
166+
PreSignedUrl: aws.String("mockPresignedURL"),
167+
})
168+
return req.Request
169+
}(),
170+
Assert: assertAsEqual("mockPresignedURL"),
171+
},
172+
opCopyDBClusterSnapshot + " presignURL set": {
173+
Req: func() *request.Request {
174+
req := svc.CopyDBClusterSnapshotRequest(
175+
&CopyDBClusterSnapshotInput{
176+
SourceRegion: aws.String("us-west-1"),
177+
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
178+
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
179+
PreSignedUrl: aws.String("mockPresignedURL"),
180+
})
181+
return req.Request
182+
}(),
183+
Assert: assertAsEqual("mockPresignedURL"),
184+
},
185+
opCreateDBCluster + " presignURL set": {
186+
Req: func() *request.Request {
187+
req := svc.CreateDBClusterRequest(
188+
&CreateDBClusterInput{
189+
SourceRegion: aws.String("us-west-1"),
190+
DBClusterIdentifier: aws.String("foo"),
191+
Engine: aws.String("bar"),
192+
PreSignedUrl: aws.String("mockPresignedURL"),
193+
})
194+
return req.Request
195+
}(),
196+
Assert: assertAsEqual("mockPresignedURL"),
197+
},
198+
}
86199

87-
reqs[opCreateDBInstanceReadReplica] = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
88-
SourceRegion: aws.String("us-west-1"),
89-
SourceDBInstanceIdentifier: aws.String("foo"),
90-
DBInstanceIdentifier: aws.String("bar"),
91-
PreSignedUrl: aws.String("presignedURL"),
92-
}).Request
200+
for name, c := range cases {
201+
t.Run(name, func(t *testing.T) {
202+
if err := c.Req.Sign(); err != nil {
203+
t.Fatalf("expect no error, got %v", err)
204+
}
205+
b, _ := ioutil.ReadAll(c.Req.HTTPRequest.Body)
206+
q, _ := url.ParseQuery(string(b))
93207

94-
for _, req := range reqs {
95-
req.Sign()
208+
u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
96209

97-
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
98-
q, _ := url.ParseQuery(string(b))
210+
c.Assert(t, u)
99211

100-
u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
101-
if e, a := "presignedURL", u; !strings.Contains(a, e) {
102-
t.Errorf("expect %s to be in %s", e, a)
103-
}
212+
})
104213
}
105214
}
106215

@@ -112,15 +221,6 @@ func TestPresignWithSourceNotSet(t *testing.T) {
112221

113222
svc := New(cfg)
114223

115-
f := func() {
116-
// Doesn't panic on nil input
117-
req := svc.CopyDBSnapshotRequest(nil)
118-
req.Sign()
119-
}
120-
if paniced, p := awstesting.DidPanic(f); paniced {
121-
t.Errorf("expect no panic, got %v", p)
122-
}
123-
124224
reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
125225
SourceDBSnapshotIdentifier: aws.String("foo"),
126226
TargetDBSnapshotIdentifier: aws.String("bar"),
@@ -133,3 +233,33 @@ func TestPresignWithSourceNotSet(t *testing.T) {
133233
}
134234
}
135235
}
236+
237+
func assertAsRegexMatch(exp string) func(*testing.T, string) {
238+
return func(t *testing.T, v string) {
239+
t.Helper()
240+
241+
if re, a := regexp.MustCompile(exp), v; !re.MatchString(a) {
242+
t.Errorf("expect %s to match %s", re, a)
243+
}
244+
}
245+
}
246+
247+
func assertAsEmpty() func(*testing.T, string) {
248+
return func(t *testing.T, v string) {
249+
t.Helper()
250+
251+
if len(v) != 0 {
252+
t.Errorf("expect empty, got %v", v)
253+
}
254+
}
255+
}
256+
257+
func assertAsEqual(expect string) func(*testing.T, string) {
258+
return func(t *testing.T, v string) {
259+
t.Helper()
260+
261+
if e, a := expect, v; e != a {
262+
t.Errorf("expect %v, got %v", e, a)
263+
}
264+
}
265+
}

0 commit comments

Comments
 (0)