Skip to content

Commit bd12807

Browse files
authored
syncing v1 rdsutils (#193)
1 parent 166597d commit bd12807

File tree

6 files changed

+336
-13
lines changed

6 files changed

+336
-13
lines changed

service/rds/rdsutils/builder.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package rdsutils
2+
3+
import (
4+
"fmt"
5+
"net/url"
6+
7+
"github.com/aws/aws-sdk-go-v2/aws"
8+
"github.com/aws/aws-sdk-go-v2/aws/awserr"
9+
)
10+
11+
// ConnectionFormat is the type of connection that will be
12+
// used to connect to the database
13+
type ConnectionFormat string
14+
15+
// ConnectionFormat enums
16+
const (
17+
NoConnectionFormat ConnectionFormat = ""
18+
TCPFormat ConnectionFormat = "tcp"
19+
)
20+
21+
// ErrNoConnectionFormat will be returned during build if no format had been
22+
// specified
23+
var ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection format was specified", nil)
24+
25+
// ConnectionStringBuilder is a builder that will construct a connection
26+
// string with the provided parameters. params field is required to have
27+
// a tls specification and allowCleartextPasswords must be set to true.
28+
type ConnectionStringBuilder struct {
29+
dbName string
30+
endpoint string
31+
region string
32+
user string
33+
credProvider aws.CredentialsProvider
34+
35+
connectFormat ConnectionFormat
36+
params url.Values
37+
}
38+
39+
// NewConnectionStringBuilder will return an ConnectionStringBuilder
40+
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, credProvider aws.CredentialsProvider) ConnectionStringBuilder {
41+
return ConnectionStringBuilder{
42+
dbName: dbName,
43+
endpoint: endpoint,
44+
region: region,
45+
user: dbUser,
46+
credProvider: credProvider,
47+
}
48+
}
49+
50+
// WithEndpoint will return a builder with the given endpoint
51+
func (b ConnectionStringBuilder) WithEndpoint(endpoint string) ConnectionStringBuilder {
52+
b.endpoint = endpoint
53+
return b
54+
}
55+
56+
// WithRegion will return a builder with the given region
57+
func (b ConnectionStringBuilder) WithRegion(region string) ConnectionStringBuilder {
58+
b.region = region
59+
return b
60+
}
61+
62+
// WithUser will return a builder with the given user
63+
func (b ConnectionStringBuilder) WithUser(user string) ConnectionStringBuilder {
64+
b.user = user
65+
return b
66+
}
67+
68+
// WithDBName will return a builder with the given database name
69+
func (b ConnectionStringBuilder) WithDBName(dbName string) ConnectionStringBuilder {
70+
b.dbName = dbName
71+
return b
72+
}
73+
74+
// WithParams will return a builder with the given params. The parameters
75+
// will be included in the connection query string
76+
//
77+
// Example:
78+
// v := url.Values{}
79+
// v.Add("tls", "rds")
80+
// b := rdsutils.NewConnectionBuilder(endpoint, region, user, dbname, credProvider)
81+
// connectStr, err := b.WithParams(v).WithTCPFormat().Build()
82+
func (b ConnectionStringBuilder) WithParams(params url.Values) ConnectionStringBuilder {
83+
b.params = params
84+
return b
85+
}
86+
87+
// WithFormat will return a builder with the given connection format
88+
func (b ConnectionStringBuilder) WithFormat(f ConnectionFormat) ConnectionStringBuilder {
89+
b.connectFormat = f
90+
return b
91+
}
92+
93+
// WithTCPFormat will set the format to TCP and return the modified builder
94+
func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder {
95+
return b.WithFormat(TCPFormat)
96+
}
97+
98+
// Build will return a new connection string that can be used to open a connection
99+
// to the desired database.
100+
//
101+
// Example:
102+
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, credProvider)
103+
// connectStr, err := b.WithTCPFormat().Build()
104+
// if err != nil {
105+
// panic(err)
106+
// }
107+
// const dbType = "mysql"
108+
// db, err := sql.Open(dbType, connectStr)
109+
func (b ConnectionStringBuilder) Build() (string, error) {
110+
if b.connectFormat == NoConnectionFormat {
111+
return "", ErrNoConnectionFormat
112+
}
113+
114+
authToken, err := BuildAuthToken(b.endpoint, b.region, b.user, b.credProvider)
115+
if err != nil {
116+
return "", err
117+
}
118+
119+
connectionStr := fmt.Sprintf("%s:%s@%s(%s)/%s",
120+
b.user, authToken, string(b.connectFormat), b.endpoint, b.dbName,
121+
)
122+
123+
if len(b.params) > 0 {
124+
connectionStr = fmt.Sprintf("%s?%s", connectionStr, b.params.Encode())
125+
}
126+
return connectionStr, nil
127+
}

service/rds/rdsutils/builder_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package rdsutils_test
2+
3+
import (
4+
"net/url"
5+
"regexp"
6+
"testing"
7+
8+
"github.com/aws/aws-sdk-go-v2/aws"
9+
"github.com/aws/aws-sdk-go-v2/service/rds/rdsutils"
10+
)
11+
12+
func TestConnectionStringBuilder(t *testing.T) {
13+
cases := []struct {
14+
user string
15+
endpoint string
16+
region string
17+
dbName string
18+
values url.Values
19+
format rdsutils.ConnectionFormat
20+
credProvider aws.CredentialsProvider
21+
22+
expectedErr error
23+
expectedConnectRegex string
24+
}{
25+
{
26+
user: "foo",
27+
endpoint: "foo.bar",
28+
region: "region",
29+
dbName: "name",
30+
format: rdsutils.NoConnectionFormat,
31+
credProvider: aws.NewStaticCredentialsProvider("AKID", "SECRET", "SESSION"),
32+
expectedErr: rdsutils.ErrNoConnectionFormat,
33+
expectedConnectRegex: "",
34+
},
35+
{
36+
user: "foo",
37+
endpoint: "foo.bar",
38+
region: "region",
39+
dbName: "name",
40+
format: rdsutils.TCPFormat,
41+
credProvider: aws.NewStaticCredentialsProvider("AKID", "SECRET", "SESSION"),
42+
expectedConnectRegex: `^foo:foo.bar\?Action=connect\&DBUser=foo.*\@tcp\(foo.bar\)/name`,
43+
},
44+
}
45+
46+
for _, c := range cases {
47+
b := rdsutils.NewConnectionStringBuilder(c.endpoint, c.region, c.user, c.dbName, c.credProvider)
48+
connectStr, err := b.WithFormat(c.format).Build()
49+
50+
if e, a := c.expectedErr, err; e != a {
51+
t.Errorf("expected %v error, but received %v", e, a)
52+
}
53+
54+
if re, a := regexp.MustCompile(c.expectedConnectRegex), connectStr; !re.MatchString(a) {
55+
t.Errorf("expect %s to match %s", re, a)
56+
}
57+
}
58+
}

service/rds/rdsutils/connect.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,13 @@ import (
99
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
1010
)
1111

12-
// BuildAuthToken will return a authentication token for the database's connect
13-
// based on the RDS database endpoint, AWS region, IAM user or role, and AWS credentials.
12+
// BuildAuthToken will return an authorization token used as the password for a DB
13+
// connection.
1414
//
15-
// Endpoint consists of the hostname and port, IE hostname:port, of the RDS database.
16-
// Region is the AWS region the RDS database is in and where the authentication token
17-
// will be generated for. DbUser is the IAM user or role the request will be authenticated
18-
// for. The creds is the AWS credentials the authentication token is signed with.
19-
//
20-
// An error is returned if the authentication token is unable to be signed with
21-
// the credentials, or the endpoint is not a valid URL.
15+
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
16+
// * region - Region is the location of where the DB is
17+
// * dbUser - User account within the database to sign in with
18+
// * creds - Credentials to be signed with
2219
//
2320
// The following example shows how to use BuildAuthToken to create an authentication
2421
// token for connecting to a MySQL database in RDS.
@@ -27,12 +24,12 @@ import (
2724
//
2825
// // Create the MySQL DNS string for the DB connection
2926
// // user:password@protocol(endpoint)/dbname?<params>
30-
// dnsStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=true",
27+
// connectStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?allowCleartextPasswords=true&tls=rds",
3128
// dbUser, authToken, dbEndpoint, dbName,
3229
// )
3330
//
3431
// // Use db to perform SQL operations on database
35-
// db, err := sql.Open("mysql", dnsStr)
32+
// db, err := sql.Open("mysql", connectStr)
3633
//
3734
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
3835
// for more information on using IAM database authentication with RDS.

service/rds/rdsutils/connect_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"regexp"
55
"testing"
66

7-
credentials "github.com/aws/aws-sdk-go-v2/aws"
7+
"github.com/aws/aws-sdk-go-v2/aws"
88
"github.com/aws/aws-sdk-go-v2/service/rds/rdsutils"
99
)
1010

@@ -30,7 +30,7 @@ func TestBuildAuthToken(t *testing.T) {
3030
}
3131

3232
for _, c := range cases {
33-
provider := credentials.NewStaticCredentialsProvider("AKID", "SECRET", "SESSION")
33+
provider := aws.NewStaticCredentialsProvider("AKID", "SECRET", "SESSION")
3434
url, err := rdsutils.BuildAuthToken(c.endpoint, c.region, c.user, provider)
3535
if err != nil {
3636
t.Errorf("expect no error, got %v", err)

service/rds/rdsutils/doc.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Package rdsutils is used to generate authentication tokens used to
2+
// connect to a givent Amazon Relational Database Service (RDS) database.
3+
//
4+
// Before using the authentication please visit the docs here to ensure
5+
// the database has the proper policies to allow for IAM token authentication.
6+
// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html#UsingWithRDS.IAMDBAuth.Availability
7+
//
8+
// When building the connection string, there are two required parameters that are needed to be set on the query.
9+
// * tls
10+
// * allowCleartextPasswords must be set to true
11+
//
12+
// Example creating a basic auth token with the builder:
13+
// v := url.Values{}
14+
// v.Add("tls", "tls_profile_name")
15+
// v.Add("allowCleartextPasswords", "true")
16+
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, credProvider)
17+
// connectStr, err := b.WithTCPFormat().WithParams(v).Build()
18+
package rdsutils

service/rds/rdsutils/example_test.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// +build example,exclude
2+
3+
package rdsutils_test
4+
5+
import (
6+
"crypto/tls"
7+
"crypto/x509"
8+
"database/sql"
9+
"flag"
10+
"fmt"
11+
"io/ioutil"
12+
"net/http"
13+
"net/url"
14+
"os"
15+
16+
"github.com/go-sql-driver/mysql"
17+
18+
"github.com/aws/aws-sdk-go-v2/aws/external"
19+
"github.com/aws/aws-sdk-go-v2/aws/stscreds"
20+
"github.com/aws/aws-sdk-go-v2/service/rds/rdsutils"
21+
"github.com/aws/aws-sdk-go-v2/service/sts"
22+
)
23+
24+
// ExampleConnectionStringBuilder contains usage of assuming a role and using
25+
// that to build the auth token.
26+
// Usage:
27+
// ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306
28+
func ExampleConnectionStringBuilder() {
29+
userPtr := flag.String("user", "", "user of the credentials")
30+
regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds")
31+
roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds")
32+
endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to")
33+
portPtr := flag.Int("port", 3306, "DB port to be connected to")
34+
tablePtr := flag.String("table", "test_table", "DB table to query against")
35+
dbNamePtr := flag.String("dbname", "", "DB name to query against")
36+
flag.Parse()
37+
38+
// Check required flags. Will exit with status code 1 if
39+
// required field isn't set.
40+
if err := requiredFlags(
41+
userPtr,
42+
regionPtr,
43+
roleArnPtr,
44+
endpointPtr,
45+
portPtr,
46+
dbNamePtr,
47+
); err != nil {
48+
fmt.Printf("Error: %v\n\n", err)
49+
flag.PrintDefaults()
50+
os.Exit(1)
51+
}
52+
53+
err := registerRDSMysqlCerts(http.DefaultClient)
54+
if err != nil {
55+
panic(err)
56+
}
57+
58+
cfg, err := external.LoadDefaultAWSConfig()
59+
cfg.Region = *regionPtr
60+
61+
stsSvc := sts.New(cfg)
62+
provider := stscreds.NewAssumeRoleProvider(stsSvc, *roleArnPtr)
63+
64+
v := url.Values{}
65+
// required fields for DB connection
66+
v.Add("tls", "rds")
67+
v.Add("allowCleartextPasswords", "true")
68+
endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr)
69+
70+
b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, provider)
71+
connectStr, err := b.WithTCPFormat().WithParams(v).Build()
72+
73+
const dbType = "mysql"
74+
db, err := sql.Open(dbType, connectStr)
75+
// if an error is encountered here, then most likely security groups are incorrect
76+
// in the database.
77+
if err != nil {
78+
panic(fmt.Errorf("failed to open connection to the database"))
79+
}
80+
81+
rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr))
82+
if err != nil {
83+
panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err))
84+
}
85+
86+
for rows.Next() {
87+
columns, err := rows.Columns()
88+
if err != nil {
89+
panic(fmt.Errorf("failed to read columns from row: %v", err))
90+
}
91+
92+
fmt.Printf("rows colums:\n%d\n", len(columns))
93+
}
94+
}
95+
96+
func requiredFlags(flags ...interface{}) error {
97+
for _, f := range flags {
98+
switch f.(type) {
99+
case nil:
100+
return fmt.Errorf("one or more required flags were not set")
101+
}
102+
}
103+
return nil
104+
}
105+
106+
func registerRDSMysqlCerts(c *http.Client) error {
107+
resp, err := c.Get("https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem")
108+
if err != nil {
109+
return err
110+
}
111+
112+
pem, err := ioutil.ReadAll(resp.Body)
113+
if err != nil {
114+
return err
115+
}
116+
117+
rootCertPool := x509.NewCertPool()
118+
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
119+
return fmt.Errorf("failed to append cert to cert pool!")
120+
}
121+
122+
return mysql.RegisterTLSConfig("rds", &tls.Config{RootCAs: rootCertPool, InsecureSkipVerify: true})
123+
}

0 commit comments

Comments
 (0)