Skip to content

Commit 9300466

Browse files
committed
[bug fix] handle ram shared VPCs for cross account tgb
1 parent 879e715 commit 9300466

File tree

2 files changed

+307
-24
lines changed

2 files changed

+307
-24
lines changed

pkg/targetgroupbinding/resource_manager.go

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package targetgroupbinding
33
import (
44
"context"
55
"fmt"
6+
"k8s.io/apimachinery/pkg/util/cache"
67
"net/netip"
78
lbcmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/lbc"
9+
"sync"
810
"time"
911

1012
elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
@@ -28,7 +30,10 @@ import (
2830
"sigs.k8s.io/controller-runtime/pkg/client"
2931
)
3032

31-
const defaultRequeueDuration = 15 * time.Second
33+
const (
	// defaultRequeueDuration is the value NewDefaultResourceManager assigns to the
	// resource manager's requeueDuration field.
	defaultRequeueDuration = 15 * time.Second

	// invalidVPCTTL is a one hour (60 minute) time-to-live for invalid-VPC cache entries.
	// NOTE(review): nothing in the visible code references this constant; the constructor
	// initializes invalidVpcCacheTTL from defaultTargetsCacheTTL instead. Confirm which
	// TTL is intended and wire it up or drop this constant.
	invalidVPCTTL = time.Hour
)
3237

3338
// ResourceManager manages the TargetGroupBinding resource.
3439
type ResourceManager interface {
@@ -64,6 +69,9 @@ func NewDefaultResourceManager(k8sClient client.Client, elbv2Client services.ELB
6469
multiClusterManager: multiClusterManager,
6570
metricsCollector: metricsCollector,
6671

72+
invalidVpcCache: cache.NewExpiring(),
73+
invalidVpcCacheTTL: defaultTargetsCacheTTL,
74+
6775
requeueDuration: defaultRequeueDuration,
6876
}
6977
}
@@ -84,6 +92,10 @@ type defaultResourceManager struct {
8492
metricsCollector lbcmetrics.MetricCollector
8593
vpcID string
8694

95+
invalidVpcCache *cache.Expiring
96+
invalidVpcCacheTTL time.Duration
97+
invalidVpcCacheMutex sync.RWMutex
98+
8799
requeueDuration time.Duration
88100
}
89101

@@ -550,29 +562,10 @@ func (m *defaultResourceManager) registerPodEndpoints(ctx context.Context, tgb *
550562
"registering endpoints using the targetGroup's vpcID %s which is different from the cluster's vpcID %s", tgb.Spec.VpcID, m.vpcID))
551563
}
552564

553-
var overrideAzFn func(addr netip.Addr) bool
554-
if tgb.Spec.IamRoleArnToAssume != "" {
555-
// If we're interacting with another account, then we should always be setting "all" AZ to allow this
556-
// target to get registered by the ELB API.
557-
overrideAzFn = func(_ netip.Addr) bool {
558-
return true
559-
}
560-
} else {
561-
vpcInfo, err := m.vpcInfoProvider.FetchVPCInfo(ctx, vpcID)
562-
if err != nil {
563-
return err
564-
}
565-
var vpcRawCIDRs []string
566-
vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv4CIDRs()...)
567-
vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv6CIDRs()...)
568-
vpcCIDRs, err := networking.ParseCIDRs(vpcRawCIDRs)
569-
if err != nil {
570-
return err
571-
}
572-
// If the pod ip resides out of all the VPC CIDRs, then the only way to force the ELB API is to use "all" AZ.
573-
overrideAzFn = func(addr netip.Addr) bool {
574-
return !networking.IsIPWithinCIDRs(addr, vpcCIDRs)
575-
}
565+
overrideAzFn, err := m.generateOverrideAzFn(ctx, vpcID, tgb.Spec.IamRoleArnToAssume)
566+
567+
if err != nil {
568+
return err
576569
}
577570

578571
sdkTargets, err := m.prepareRegistrationCall(endpoints, overrideAzFn)
@@ -626,6 +619,66 @@ func (m *defaultResourceManager) updateTGBCheckPoint(ctx context.Context, tgb *e
626619
return nil
627620
}
628621

622+
func (m *defaultResourceManager) generateOverrideAzFn(ctx context.Context, vpcID string, assumeRole string) (func(addr netip.Addr) bool, error) {
623+
// Cross-Account is configured by assuming a role.
624+
usingCrossAccount := assumeRole != ""
625+
626+
// We need to cache the vpc response for the various assume roles.
627+
// There are two cases to consider when using assuming a role:
628+
// 1. Using a peered VPC connection to provide connectivity among accounts.
629+
// 2. Using RAM shared subnet(s) to provide connectivity among accounts.
630+
// We need to handle the case where the user is potentially using the same VPC in the peered context
631+
// as well as the RAM shared context.
632+
// Using peered VPC connection, we will always need to override the AZ.
633+
// Using a RAM shared subnet / VPC means that we follow the standard logic of checking the pod ip against the VPC CIDRs.
634+
635+
invalidVPCCacheKey := fmt.Sprintf("%s-%s", assumeRole, vpcID)
636+
637+
if usingCrossAccount {
638+
// Prevent spamming EC2 with requests.
639+
// We can use the cached result for this VPC ID given for the current assume role ARN
640+
m.invalidVpcCacheMutex.RLock()
641+
_, invalidVPC := m.invalidVpcCache.Get(invalidVPCCacheKey)
642+
m.invalidVpcCacheMutex.RUnlock()
643+
644+
// In this case, we already received that this VPC was invalid, we can shortcut the EC2 call and just override the AZ.
645+
if invalidVPC {
646+
return func(addr netip.Addr) bool {
647+
return true
648+
}, nil
649+
}
650+
}
651+
652+
vpcInfo, err := m.vpcInfoProvider.FetchVPCInfo(ctx, vpcID)
653+
if err != nil {
654+
// A VPC Not Found Error along with cross-account usage means that the VPC either, is not shared with the assume
655+
// role account OR this falls into case (1) from above where the VPC is just peered but not shared with RAM.
656+
// As we can't differentiate if RAM sharing wasn't set up correctly OR the VPC is set up via peering, we will
657+
// just default to assume that the VPC is peered but not shared.
658+
if isVPCNotFoundError(err) && usingCrossAccount {
659+
m.invalidVpcCacheMutex.Lock()
660+
m.invalidVpcCache.Set(invalidVPCCacheKey, true, m.invalidVpcCacheTTL)
661+
m.invalidVpcCacheMutex.Unlock()
662+
return func(addr netip.Addr) bool {
663+
return true
664+
}, nil
665+
}
666+
return nil, err
667+
}
668+
var vpcRawCIDRs []string
669+
vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv4CIDRs()...)
670+
vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv6CIDRs()...)
671+
vpcCIDRs, err := networking.ParseCIDRs(vpcRawCIDRs)
672+
if err != nil {
673+
return nil, err
674+
}
675+
// By getting here, we have a valid VPC for whatever credential was used. We return "true" in the function below
676+
// when the pod ip falls outside the VPCs configured CIDRs, other we return "false" to ensure that the "all" is NOT injected.
677+
return func(addr netip.Addr) bool {
678+
return !networking.IsIPWithinCIDRs(addr, vpcCIDRs)
679+
}, nil
680+
}
681+
629682
type podEndpointAndTargetPair struct {
630683
endpoint backend.PodEndpoint
631684
target TargetInfo
@@ -747,3 +800,12 @@ func isELBV2TargetGroupARNInvalidError(err error) bool {
747800
}
748801
return false
749802
}
803+
804+
func isVPCNotFoundError(err error) bool {
805+
var apiErr smithy.APIError
806+
if errors.As(err, &apiErr) {
807+
code := apiErr.ErrorCode()
808+
return code == "InvalidVpcID.NotFound"
809+
}
810+
return false
811+
}

pkg/targetgroupbinding/resource_manager_test.go

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@ package targetgroupbinding
22

33
import (
44
"context"
5+
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
56
elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
7+
"github.com/aws/aws-sdk-go/aws"
8+
"github.com/aws/smithy-go"
9+
"github.com/golang/mock/gomock"
10+
"k8s.io/apimachinery/pkg/util/cache"
11+
"net/netip"
612
elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1"
713
lbcmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/lbc"
14+
"sigs.k8s.io/aws-load-balancer-controller/pkg/networking"
815
"testing"
916

1017
awssdk "github.com/aws/aws-sdk-go-v2/aws"
@@ -481,3 +488,217 @@ func Test_containsTargetsInInitialState(t *testing.T) {
481488
})
482489
}
483490
}
491+
492+
func Test_defaultResourceManager_GenerateOverrideAzFn(t *testing.T) {
493+
494+
vpcId := "foo"
495+
496+
type ipTestCase struct {
497+
ip netip.Addr
498+
result bool
499+
}
500+
501+
testCases := []struct {
502+
name string
503+
vpcInfoCalls int
504+
assumeRole string
505+
vpcInfo networking.VPCInfo
506+
vpcInfoError error
507+
ipTestCases []ipTestCase
508+
expectErr bool
509+
}{
510+
{
511+
name: "standard case ipv4",
512+
vpcInfoCalls: 1,
513+
vpcInfo: networking.VPCInfo{
514+
CidrBlockAssociationSet: []ec2types.VpcCidrBlockAssociation{
515+
{
516+
CidrBlock: aws.String("127.0.0.0/24"),
517+
CidrBlockState: &ec2types.VpcCidrBlockState{
518+
State: ec2types.VpcCidrBlockStateCodeAssociated,
519+
},
520+
},
521+
},
522+
},
523+
ipTestCases: []ipTestCase{
524+
{
525+
ip: netip.MustParseAddr("172.0.0.0"),
526+
result: true,
527+
},
528+
{
529+
ip: netip.MustParseAddr("127.0.0.1"),
530+
result: false,
531+
},
532+
{
533+
ip: netip.MustParseAddr("127.0.0.2"),
534+
result: false,
535+
},
536+
},
537+
},
538+
{
539+
name: "standard case ipv6",
540+
vpcInfoCalls: 1,
541+
vpcInfo: networking.VPCInfo{
542+
Ipv6CidrBlockAssociationSet: []ec2types.VpcIpv6CidrBlockAssociation{
543+
{
544+
Ipv6CidrBlock: aws.String("2001:db8::/32"),
545+
Ipv6CidrBlockState: &ec2types.VpcCidrBlockState{
546+
State: ec2types.VpcCidrBlockStateCodeAssociated,
547+
},
548+
},
549+
},
550+
},
551+
ipTestCases: []ipTestCase{
552+
{
553+
ip: netip.MustParseAddr("5001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
554+
result: true,
555+
},
556+
{
557+
ip: netip.MustParseAddr("2001:db8:0:0:0:0:0:0"),
558+
result: false,
559+
},
560+
{
561+
ip: netip.MustParseAddr("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
562+
result: false,
563+
},
564+
},
565+
},
566+
{
567+
name: "assume role case ram shared vpc ipv4",
568+
vpcInfoCalls: 1,
569+
assumeRole: "foo",
570+
vpcInfo: networking.VPCInfo{
571+
CidrBlockAssociationSet: []ec2types.VpcCidrBlockAssociation{
572+
{
573+
CidrBlock: aws.String("127.0.0.0/24"),
574+
CidrBlockState: &ec2types.VpcCidrBlockState{
575+
State: ec2types.VpcCidrBlockStateCodeAssociated,
576+
},
577+
},
578+
},
579+
},
580+
ipTestCases: []ipTestCase{
581+
{
582+
ip: netip.MustParseAddr("172.0.0.0"),
583+
result: true,
584+
},
585+
{
586+
ip: netip.MustParseAddr("127.0.0.1"),
587+
result: false,
588+
},
589+
{
590+
ip: netip.MustParseAddr("127.0.0.2"),
591+
result: false,
592+
},
593+
},
594+
},
595+
{
596+
name: "assume role ram shared vpc case ipv6",
597+
vpcInfoCalls: 1,
598+
assumeRole: "foo",
599+
vpcInfo: networking.VPCInfo{
600+
Ipv6CidrBlockAssociationSet: []ec2types.VpcIpv6CidrBlockAssociation{
601+
{
602+
Ipv6CidrBlock: aws.String("2001:db8::/32"),
603+
Ipv6CidrBlockState: &ec2types.VpcCidrBlockState{
604+
State: ec2types.VpcCidrBlockStateCodeAssociated,
605+
},
606+
},
607+
},
608+
},
609+
ipTestCases: []ipTestCase{
610+
{
611+
ip: netip.MustParseAddr("5001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
612+
result: true,
613+
},
614+
{
615+
ip: netip.MustParseAddr("2001:db8:0:0:0:0:0:0"),
616+
result: false,
617+
},
618+
{
619+
ip: netip.MustParseAddr("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
620+
result: false,
621+
},
622+
},
623+
},
624+
{
625+
name: "assume role case peered vpc ipv4",
626+
vpcInfoCalls: 1,
627+
assumeRole: "foo",
628+
vpcInfoError: &smithy.GenericAPIError{Code: "InvalidVpcID.NotFound", Message: ""},
629+
ipTestCases: []ipTestCase{
630+
{
631+
ip: netip.MustParseAddr("172.0.0.0"),
632+
result: true,
633+
},
634+
{
635+
ip: netip.MustParseAddr("127.0.0.1"),
636+
result: true,
637+
},
638+
{
639+
ip: netip.MustParseAddr("127.0.0.2"),
640+
result: true,
641+
},
642+
},
643+
},
644+
{
645+
name: "assume role peered vpc case ipv6",
646+
vpcInfoCalls: 1,
647+
assumeRole: "foo",
648+
vpcInfoError: &smithy.GenericAPIError{Code: "InvalidVpcID.NotFound", Message: ""},
649+
ipTestCases: []ipTestCase{
650+
{
651+
ip: netip.MustParseAddr("5001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
652+
result: true,
653+
},
654+
{
655+
ip: netip.MustParseAddr("2001:db8:0:0:0:0:0:0"),
656+
result: true,
657+
},
658+
{
659+
ip: netip.MustParseAddr("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
660+
result: true,
661+
},
662+
},
663+
},
664+
{
665+
name: "not found error from vpc info should be propagated when not using assume role",
666+
vpcInfoCalls: 1,
667+
vpcInfoError: &smithy.GenericAPIError{Code: "InvalidVpcID.NotFound", Message: ""},
668+
expectErr: true,
669+
},
670+
{
671+
name: "assume role case peered vpc other error should get propagated",
672+
vpcInfoCalls: 1,
673+
assumeRole: "foo",
674+
vpcInfoError: &smithy.GenericAPIError{Code: "other error", Message: ""},
675+
expectErr: true,
676+
},
677+
}
678+
679+
for _, tc := range testCases {
680+
t.Run(tc.name, func(t *testing.T) {
681+
ctrl := gomock.NewController(t)
682+
defer ctrl.Finish()
683+
vpcInfoProvider := networking.NewMockVPCInfoProvider(ctrl)
684+
vpcInfoProvider.EXPECT().FetchVPCInfo(gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vpcInfo, tc.vpcInfoError).Times(tc.vpcInfoCalls)
685+
m := &defaultResourceManager{
686+
logger: logr.New(&log.NullLogSink{}),
687+
invalidVpcCache: cache.NewExpiring(),
688+
vpcInfoProvider: vpcInfoProvider,
689+
}
690+
691+
returnedFn, err := m.generateOverrideAzFn(context.Background(), vpcId, tc.assumeRole)
692+
693+
if tc.expectErr {
694+
assert.Error(t, err)
695+
return
696+
}
697+
698+
assert.NoError(t, err)
699+
for _, iptc := range tc.ipTestCases {
700+
assert.Equal(t, iptc.result, returnedFn(iptc.ip), iptc.ip)
701+
}
702+
})
703+
}
704+
}

0 commit comments

Comments
 (0)