Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions pkg/networking/pod_eni_info_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ const (
// EC2:DescribeNetworkInterface supports up to 200 filters per call.
describeNetworkInterfacesFiltersLimit = 200

labelEKSComputeType = "eks.amazonaws.com/compute-type"
labelEKSComputeType = "eks.amazonaws.com/compute-type"
labelSageMakerComputeType = "sagemaker.amazonaws.com/compute-type"
)

// PodENIInfoResolver is responsible for resolve the AWS VPC ENI that supports pod network.
Expand Down Expand Up @@ -141,7 +142,7 @@ func (r *defaultPodENIInfoResolver) saveENIInfosToCache(pods []k8s.PodInfo, eniI
}

func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
podsOnEc2, podsOnFargate, err := r.classifyPodsByComputeType(ctx, pods)
podsOnEc2, podsOnFargate, podsOnSageMakerHyperPod, err := r.classifyPodsByComputeType(ctx, pods)
if err != nil {
return nil, err
}
Expand All @@ -164,17 +165,28 @@ func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Con
}
}
}
if len(podsOnSageMakerHyperPod) > 0 {
eniInfoByPodKeySageMakerHyperPod, err := r.resolveViaCascadedLookup(ctx, podsOnSageMakerHyperPod, true)
if err != nil {
return nil, err
}
if len(eniInfoByPodKeySageMakerHyperPod) > 0 {
for podKey, eniInfo := range eniInfoByPodKeySageMakerHyperPod {
eniInfoByPodKey[podKey] = eniInfo
}
}
}
return eniInfoByPodKey, nil
}

func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isFargateNode bool) (map[types.NamespacedName]ENIInfo, error) {
func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isNonEc2Pod bool) (map[types.NamespacedName]ENIInfo, error) {
eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo)
resolveFuncs := []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
r.resolveViaPodENIAnnotation,
r.resolveViaNodeENIs,
// TODO, add support for kubenet CNI plugin(kops) by resolve via routeTable.
}
if isFargateNode {
if isNonEc2Pod {
resolveFuncs = []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
r.resolveViaVPCENIs,
}
Expand Down Expand Up @@ -281,6 +293,7 @@ func (r *defaultPodENIInfoResolver) resolveViaNodeENIs(ctx context.Context, pods

// resolveViaVPCENIs tries to resolve pod ENI by matching podIP against ENIs in vpc.
// with EKS fargate pods, podIP is supported by an ENI in vpc.
// with SageMaker HyperPod pods, podIP is supported by the visible cross-account ENI in customer vpc.
func (r *defaultPodENIInfoResolver) resolveViaVPCENIs(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
podKeysByIP := make(map[string][]types.NamespacedName, len(pods))
for _, pod := range pods {
Expand Down Expand Up @@ -388,33 +401,39 @@ func (r *defaultPodENIInfoResolver) isPodSupportedByNodeENI(pod k8s.PodInfo, nod
return false
}

// classifyPodsByComputeType classifies in to ec2 and fargate groups
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) ([]k8s.PodInfo, []k8s.PodInfo, error) {
// classifyPodsByComputeType classifies in to ec2, fargate and sagemaker-hyperpod groups
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) ([]k8s.PodInfo, []k8s.PodInfo, []k8s.PodInfo, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you refactor this to be more type friendly? E.g. it's non-trivial to remember arg1 is ec2, arg2 is fargate, arg3 is hyperpod.

Can you introduce a new type that basically is

struct{
ec2Pods []k8s.PodInfo
fargatePods []k8s.PodInfo
hyperpodPod []k8s.PodInfo
}

this will make the code cleaner and as we add more compute types it would be more extensible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup. refactored in the last commit. thanks!

podsOnFargate := make([]k8s.PodInfo, 0, len(pods))
podsOnEc2 := make([]k8s.PodInfo, 0, len(pods))
podsOnSageMakerHyperPod := make([]k8s.PodInfo, 0, len(pods))
nodeNameByComputeType := make(map[string]string)
for _, pod := range pods {
if _, exists := nodeNameByComputeType[pod.NodeName]; exists {
if nodeNameByComputeType[pod.NodeName] == "fargate" {
podsOnFargate = append(podsOnFargate, pod)
} else if nodeNameByComputeType[pod.NodeName] == "sagemaker-hyperpod" {
podsOnSageMakerHyperPod = append(podsOnSageMakerHyperPod, pod)
} else {
podsOnEc2 = append(podsOnEc2, pod)
}
}
nodeKey := types.NamespacedName{Name: pod.NodeName}
node := &corev1.Node{}
if err := r.k8sClient.Get(ctx, nodeKey, node); err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if node.Labels[labelEKSComputeType] == "fargate" {
podsOnFargate = append(podsOnFargate, pod)
nodeNameByComputeType[pod.NodeName] = "fargate"
} else if node.Labels[labelSageMakerComputeType] == "hyperpod" {
podsOnSageMakerHyperPod = append(podsOnSageMakerHyperPod, pod)
nodeNameByComputeType[pod.NodeName] = "sagemaker-hyperpod"
} else {
podsOnEc2 = append(podsOnEc2, pod)
nodeNameByComputeType[pod.NodeName] = "ec2"
}
}
return podsOnEc2, podsOnFargate, nil
return podsOnEc2, podsOnFargate, podsOnSageMakerHyperPod, nil
}

// computePodENIInfoCacheKey computes the cacheKey for pod's ENIInfo cache.
Expand Down
181 changes: 181 additions & 0 deletions pkg/networking/pod_eni_info_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,187 @@ func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_Fargate(t *testing.
}
}

func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_SageMakerHyperPod(t *testing.T) {
hyperPodNodeA := &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "hyperpod-i-04442beca624ba65b",
Labels: map[string]string{
"sagemaker.amazonaws.com/compute-type": "hyperpod",
},
},
Spec: corev1.NodeSpec{
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04442beca624ba65b",
},
}
hyperPodNodeB := &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "hyperpod-i-04159267183583d03",
Labels: map[string]string{
"sagemaker.amazonaws.com/compute-type": "hyperpod",
},
},
Spec: corev1.NodeSpec{
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04159267183583d03",
},
}
type describeNetworkInterfacesAsListCall struct {
req *ec2sdk.DescribeNetworkInterfacesInput
resp []ec2types.NetworkInterface
err error
}
type fetchNodeInstancesCall struct {
nodes []*corev1.Node
nodeInstanceByNodeKey map[types.NamespacedName]*ec2types.Instance
err error
}
type env struct {
nodes []*corev1.Node
}
type fields struct {
describeNetworkInterfacesAsListCalls []describeNetworkInterfacesAsListCall
fetchNodeInstancesCalls []fetchNodeInstancesCall
}
type args struct {
pods []k8s.PodInfo
}
tests := []struct {
name string
env env
fields fields
args args
want map[types.NamespacedName]ENIInfo
wantErr error
}{
{
name: "all pod's ENI resolved via VPC's ENIs",
env: env{
nodes: []*corev1.Node{hyperPodNodeA, hyperPodNodeB},
},
fields: fields{
describeNetworkInterfacesAsListCalls: []describeNetworkInterfacesAsListCall{
{
req: &ec2sdk.DescribeNetworkInterfacesInput{
Filters: []ec2types.Filter{
{
Name: awssdk.String("vpc-id"),
Values: []string{"vpc-0d6d9ee10bd062dcc"},
},
{
Name: awssdk.String("addresses.private-ip-address"),
Values: []string{"192.168.128.151", "192.168.128.152"},
},
},
},
resp: []ec2types.NetworkInterface{
{
NetworkInterfaceId: awssdk.String("eni-c"),
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
{
PrivateIpAddress: awssdk.String("192.168.128.150"),
},
{
PrivateIpAddress: awssdk.String("192.168.128.151"),
},
},
Groups: []ec2types.GroupIdentifier{
{
GroupId: awssdk.String("sg-c-1"),
},
},
},
{
NetworkInterfaceId: awssdk.String("eni-d"),
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
{
PrivateIpAddress: awssdk.String("192.168.128.152"),
},
{
PrivateIpAddress: awssdk.String("192.168.128.153"),
},
},
Groups: []ec2types.GroupIdentifier{
{
GroupId: awssdk.String("sg-d-1"),
},
},
},
},
},
},
},
args: args{
pods: []k8s.PodInfo{
{
Key: types.NamespacedName{Namespace: "default", Name: "pod-1"},
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc01"),
NodeName: "hyperpod-i-04442beca624ba65b",
PodIP: "192.168.128.151",
},
{
Key: types.NamespacedName{Namespace: "default", Name: "pod-2"},
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc02"),
NodeName: "hyperpod-i-04159267183583d03",
PodIP: "192.168.128.152",
},
},
},
want: map[types.NamespacedName]ENIInfo{
types.NamespacedName{Namespace: "default", Name: "pod-1"}: {
NetworkInterfaceID: "eni-c",
SecurityGroups: []string{"sg-c-1"},
},
types.NamespacedName{Namespace: "default", Name: "pod-2"}: {
NetworkInterfaceID: "eni-d",
SecurityGroups: []string{"sg-d-1"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ec2Client := services.NewMockEC2(ctrl)
for _, call := range tt.fields.describeNetworkInterfacesAsListCalls {
ec2Client.EXPECT().DescribeNetworkInterfacesAsList(gomock.Any(), call.req).Return(call.resp, call.err)
}
k8sSchema := runtime.NewScheme()
clientgoscheme.AddToScheme(k8sSchema)
k8sClient := fake.NewClientBuilder().WithScheme(k8sSchema).Build()
for _, node := range tt.env.nodes {
assert.NoError(t, k8sClient.Create(context.Background(), node.DeepCopy()))
}
nodeInfoProvider := NewMockNodeInfoProvider(ctrl)
for _, call := range tt.fields.fetchNodeInstancesCalls {
updatedNodes := make([]*corev1.Node, 0, len(call.nodes))
for _, node := range call.nodes {
updatedNode := &corev1.Node{}
assert.NoError(t, k8sClient.Get(context.Background(), k8s.NamespacedName(node), updatedNode))
updatedNodes = append(updatedNodes, updatedNode)
}
nodeInfoProvider.EXPECT().FetchNodeInstances(gomock.Any(), gomock.InAnyOrder(updatedNodes)).Return(call.nodeInstanceByNodeKey, call.err)
}
r := &defaultPodENIInfoResolver{
ec2Client: ec2Client,
k8sClient: k8sClient,
nodeInfoProvider: nodeInfoProvider,
vpcID: "vpc-0d6d9ee10bd062dcc",
logger: logr.New(&log.NullLogSink{}),
describeNetworkInterfacesIPChunkSize: 2,
}

got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, true)
if tt.wantErr != nil {
assert.EqualError(t, err, tt.wantErr.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

func Test_defaultPodENIInfoResolver_resolveViaPodENIAnnotation(t *testing.T) {
type describeNetworkInterfacesAsListCall struct {
req *ec2sdk.DescribeNetworkInterfacesInput
Expand Down