Skip to content

Commit 77ab2bc

Browse files
add sagemaker-hyperpod compute type to resolve its pods via VPC ENI
1 parent 0700e85 commit 77ab2bc

File tree

2 files changed

+212
-12
lines changed

2 files changed

+212
-12
lines changed

pkg/networking/pod_eni_info_resolver.go

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ const (
2626
// EC2:DescribeNetworkInterface supports up to 200 filters per call.
2727
describeNetworkInterfacesFiltersLimit = 200
2828

29-
labelEKSComputeType = "eks.amazonaws.com/compute-type"
29+
labelEKSComputeType = "eks.amazonaws.com/compute-type"
30+
labelSageMakerComputeType = "sagemaker.amazonaws.com/compute-type"
3031
)
3132

3233
// PodENIInfoResolver is responsible for resolve the AWS VPC ENI that supports pod network.
@@ -141,20 +142,20 @@ func (r *defaultPodENIInfoResolver) saveENIInfosToCache(pods []k8s.PodInfo, eniI
141142
}
142143

143144
func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
144-
podsOnEc2, podsOnFargate, err := r.classifyPodsByComputeType(ctx, pods)
145+
podsOnEc2, podsOnFargate, podsOnSageMakerHyperPod, err := r.classifyPodsByComputeType(ctx, pods)
145146
if err != nil {
146147
return nil, err
147148
}
148149
eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo)
149150
if len(podsOnEc2) > 0 {
150-
eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsOnEc2, false)
151+
eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsOnEc2, false, false)
151152
if err != nil {
152153
return nil, err
153154
}
154155
eniInfoByPodKey = eniInfoByPodKeyEc2
155156
}
156157
if len(podsOnFargate) > 0 {
157-
eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsOnFargate, true)
158+
eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsOnFargate, true, false)
158159
if err != nil {
159160
return nil, err
160161
}
@@ -164,17 +165,28 @@ func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Con
164165
}
165166
}
166167
}
168+
if len(podsOnSageMakerHyperPod) > 0 {
169+
eniInfoByPodKeySageMakerHyperPod, err := r.resolveViaCascadedLookup(ctx, podsOnSageMakerHyperPod, false, true)
170+
if err != nil {
171+
return nil, err
172+
}
173+
if len(eniInfoByPodKeySageMakerHyperPod) > 0 {
174+
for podKey, eniInfo := range eniInfoByPodKeySageMakerHyperPod {
175+
eniInfoByPodKey[podKey] = eniInfo
176+
}
177+
}
178+
}
167179
return eniInfoByPodKey, nil
168180
}
169181

170-
func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isFargateNode bool) (map[types.NamespacedName]ENIInfo, error) {
182+
func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isFargateNode bool, isSageMakerHyperPodNode bool) (map[types.NamespacedName]ENIInfo, error) {
171183
eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo)
172184
resolveFuncs := []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
173185
r.resolveViaPodENIAnnotation,
174186
r.resolveViaNodeENIs,
175187
// TODO, add support for kubenet CNI plugin(kops) by resolve via routeTable.
176188
}
177-
if isFargateNode {
189+
if isFargateNode || isSageMakerHyperPodNode {
178190
resolveFuncs = []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
179191
r.resolveViaVPCENIs,
180192
}
@@ -281,6 +293,7 @@ func (r *defaultPodENIInfoResolver) resolveViaNodeENIs(ctx context.Context, pods
281293

282294
// resolveViaVPCENIs tries to resolve pod ENI by matching podIP against ENIs in vpc.
283295
// with EKS fargate pods, podIP is supported by an ENI in vpc.
296+
// with SageMaker HyperPod pods, podIP is supported by the visible cross-account ENI in customer vpc.
284297
func (r *defaultPodENIInfoResolver) resolveViaVPCENIs(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
285298
podKeysByIP := make(map[string][]types.NamespacedName, len(pods))
286299
for _, pod := range pods {
@@ -388,33 +401,39 @@ func (r *defaultPodENIInfoResolver) isPodSupportedByNodeENI(pod k8s.PodInfo, nod
388401
return false
389402
}
390403

391-
// classifyPodsByComputeType classifies in to ec2 and fargate groups
392-
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) ([]k8s.PodInfo, []k8s.PodInfo, error) {
404+
// classifyPodsByComputeType classifies in to ec2, fargate and sagemaker-hyperpod groups
405+
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) ([]k8s.PodInfo, []k8s.PodInfo, []k8s.PodInfo, error) {
393406
podsOnFargate := make([]k8s.PodInfo, 0, len(pods))
394407
podsOnEc2 := make([]k8s.PodInfo, 0, len(pods))
408+
podsOnSageMakerHyperPod := make([]k8s.PodInfo, 0, len(pods))
395409
nodeNameByComputeType := make(map[string]string)
396410
for _, pod := range pods {
397411
if _, exists := nodeNameByComputeType[pod.NodeName]; exists {
398412
if nodeNameByComputeType[pod.NodeName] == "fargate" {
399413
podsOnFargate = append(podsOnFargate, pod)
414+
} else if nodeNameByComputeType[pod.NodeName] == "sagemaker-hyperpod" {
415+
podsOnSageMakerHyperPod = append(podsOnSageMakerHyperPod, pod)
400416
} else {
401417
podsOnEc2 = append(podsOnEc2, pod)
402418
}
403419
}
404420
nodeKey := types.NamespacedName{Name: pod.NodeName}
405421
node := &corev1.Node{}
406422
if err := r.k8sClient.Get(ctx, nodeKey, node); err != nil {
407-
return nil, nil, err
423+
return nil, nil, nil, err
408424
}
409425
if node.Labels[labelEKSComputeType] == "fargate" {
410426
podsOnFargate = append(podsOnFargate, pod)
411427
nodeNameByComputeType[pod.NodeName] = "fargate"
428+
} else if node.Labels[labelSageMakerComputeType] == "hyperpod" {
429+
podsOnSageMakerHyperPod = append(podsOnSageMakerHyperPod, pod)
430+
nodeNameByComputeType[pod.NodeName] = "sagemaker-hyperpod"
412431
} else {
413432
podsOnEc2 = append(podsOnEc2, pod)
414433
nodeNameByComputeType[pod.NodeName] = "ec2"
415434
}
416435
}
417-
return podsOnEc2, podsOnFargate, nil
436+
return podsOnEc2, podsOnFargate, podsOnSageMakerHyperPod, nil
418437
}
419438

420439
// computePodENIInfoCacheKey computes the cacheKey for pod's ENIInfo cache.

pkg/networking/pod_eni_info_resolver_test.go

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_EC2(t *testing.T) {
807807
describeNetworkInterfacesIPChunkSize: 2,
808808
}
809809

810-
got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, false)
810+
got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, false, false)
811811
if tt.wantErr != nil {
812812
assert.EqualError(t, err, tt.wantErr.Error())
813813
} else {
@@ -988,7 +988,188 @@ func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_Fargate(t *testing.
988988
describeNetworkInterfacesIPChunkSize: 2,
989989
}
990990

991-
got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, true)
991+
got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, true, false)
992+
if tt.wantErr != nil {
993+
assert.EqualError(t, err, tt.wantErr.Error())
994+
} else {
995+
assert.NoError(t, err)
996+
assert.Equal(t, tt.want, got)
997+
}
998+
})
999+
}
1000+
}
1001+
1002+
func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_SageMakerHyperPod(t *testing.T) {
1003+
hyperPodNodeA := &corev1.Node{
1004+
ObjectMeta: metav1.ObjectMeta{
1005+
Name: "hyperpod-i-04442beca624ba65b",
1006+
Labels: map[string]string{
1007+
"sagemaker.amazonaws.com/compute-type": "hyperpod",
1008+
},
1009+
},
1010+
Spec: corev1.NodeSpec{
1011+
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04442beca624ba65b",
1012+
},
1013+
}
1014+
hyperPodNodeB := &corev1.Node{
1015+
ObjectMeta: metav1.ObjectMeta{
1016+
Name: "hyperpod-i-04159267183583d03",
1017+
Labels: map[string]string{
1018+
"sagemaker.amazonaws.com/compute-type": "hyperpod",
1019+
},
1020+
},
1021+
Spec: corev1.NodeSpec{
1022+
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04159267183583d03",
1023+
},
1024+
}
1025+
type describeNetworkInterfacesAsListCall struct {
1026+
req *ec2sdk.DescribeNetworkInterfacesInput
1027+
resp []ec2types.NetworkInterface
1028+
err error
1029+
}
1030+
type fetchNodeInstancesCall struct {
1031+
nodes []*corev1.Node
1032+
nodeInstanceByNodeKey map[types.NamespacedName]*ec2types.Instance
1033+
err error
1034+
}
1035+
type env struct {
1036+
nodes []*corev1.Node
1037+
}
1038+
type fields struct {
1039+
describeNetworkInterfacesAsListCalls []describeNetworkInterfacesAsListCall
1040+
fetchNodeInstancesCalls []fetchNodeInstancesCall
1041+
}
1042+
type args struct {
1043+
pods []k8s.PodInfo
1044+
}
1045+
tests := []struct {
1046+
name string
1047+
env env
1048+
fields fields
1049+
args args
1050+
want map[types.NamespacedName]ENIInfo
1051+
wantErr error
1052+
}{
1053+
{
1054+
name: "all pod's ENI resolved via VPC's ENIs",
1055+
env: env{
1056+
nodes: []*corev1.Node{hyperPodNodeA, hyperPodNodeB},
1057+
},
1058+
fields: fields{
1059+
describeNetworkInterfacesAsListCalls: []describeNetworkInterfacesAsListCall{
1060+
{
1061+
req: &ec2sdk.DescribeNetworkInterfacesInput{
1062+
Filters: []ec2types.Filter{
1063+
{
1064+
Name: awssdk.String("vpc-id"),
1065+
Values: []string{"vpc-0d6d9ee10bd062dcc"},
1066+
},
1067+
{
1068+
Name: awssdk.String("addresses.private-ip-address"),
1069+
Values: []string{"192.168.128.151", "192.168.128.152"},
1070+
},
1071+
},
1072+
},
1073+
resp: []ec2types.NetworkInterface{
1074+
{
1075+
NetworkInterfaceId: awssdk.String("eni-c"),
1076+
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
1077+
{
1078+
PrivateIpAddress: awssdk.String("192.168.128.150"),
1079+
},
1080+
{
1081+
PrivateIpAddress: awssdk.String("192.168.128.151"),
1082+
},
1083+
},
1084+
Groups: []ec2types.GroupIdentifier{
1085+
{
1086+
GroupId: awssdk.String("sg-c-1"),
1087+
},
1088+
},
1089+
},
1090+
{
1091+
NetworkInterfaceId: awssdk.String("eni-d"),
1092+
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
1093+
{
1094+
PrivateIpAddress: awssdk.String("192.168.128.152"),
1095+
},
1096+
{
1097+
PrivateIpAddress: awssdk.String("192.168.128.153"),
1098+
},
1099+
},
1100+
Groups: []ec2types.GroupIdentifier{
1101+
{
1102+
GroupId: awssdk.String("sg-d-1"),
1103+
},
1104+
},
1105+
},
1106+
},
1107+
},
1108+
},
1109+
},
1110+
args: args{
1111+
pods: []k8s.PodInfo{
1112+
{
1113+
Key: types.NamespacedName{Namespace: "default", Name: "pod-1"},
1114+
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc01"),
1115+
NodeName: "hyperpod-i-04442beca624ba65b",
1116+
PodIP: "192.168.128.151",
1117+
},
1118+
{
1119+
Key: types.NamespacedName{Namespace: "default", Name: "pod-2"},
1120+
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc02"),
1121+
NodeName: "hyperpod-i-04159267183583d03",
1122+
PodIP: "192.168.128.152",
1123+
},
1124+
},
1125+
},
1126+
want: map[types.NamespacedName]ENIInfo{
1127+
types.NamespacedName{Namespace: "default", Name: "pod-1"}: {
1128+
NetworkInterfaceID: "eni-c",
1129+
SecurityGroups: []string{"sg-c-1"},
1130+
},
1131+
types.NamespacedName{Namespace: "default", Name: "pod-2"}: {
1132+
NetworkInterfaceID: "eni-d",
1133+
SecurityGroups: []string{"sg-d-1"},
1134+
},
1135+
},
1136+
},
1137+
}
1138+
for _, tt := range tests {
1139+
t.Run(tt.name, func(t *testing.T) {
1140+
ctrl := gomock.NewController(t)
1141+
defer ctrl.Finish()
1142+
1143+
ec2Client := services.NewMockEC2(ctrl)
1144+
for _, call := range tt.fields.describeNetworkInterfacesAsListCalls {
1145+
ec2Client.EXPECT().DescribeNetworkInterfacesAsList(gomock.Any(), call.req).Return(call.resp, call.err)
1146+
}
1147+
k8sSchema := runtime.NewScheme()
1148+
clientgoscheme.AddToScheme(k8sSchema)
1149+
k8sClient := fake.NewClientBuilder().WithScheme(k8sSchema).Build()
1150+
for _, node := range tt.env.nodes {
1151+
assert.NoError(t, k8sClient.Create(context.Background(), node.DeepCopy()))
1152+
}
1153+
nodeInfoProvider := NewMockNodeInfoProvider(ctrl)
1154+
for _, call := range tt.fields.fetchNodeInstancesCalls {
1155+
updatedNodes := make([]*corev1.Node, 0, len(call.nodes))
1156+
for _, node := range call.nodes {
1157+
updatedNode := &corev1.Node{}
1158+
assert.NoError(t, k8sClient.Get(context.Background(), k8s.NamespacedName(node), updatedNode))
1159+
updatedNodes = append(updatedNodes, updatedNode)
1160+
}
1161+
nodeInfoProvider.EXPECT().FetchNodeInstances(gomock.Any(), gomock.InAnyOrder(updatedNodes)).Return(call.nodeInstanceByNodeKey, call.err)
1162+
}
1163+
r := &defaultPodENIInfoResolver{
1164+
ec2Client: ec2Client,
1165+
k8sClient: k8sClient,
1166+
nodeInfoProvider: nodeInfoProvider,
1167+
vpcID: "vpc-0d6d9ee10bd062dcc",
1168+
logger: logr.New(&log.NullLogSink{}),
1169+
describeNetworkInterfacesIPChunkSize: 2,
1170+
}
1171+
1172+
got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, false, true)
9921173
if tt.wantErr != nil {
9931174
assert.EqualError(t, err, tt.wantErr.Error())
9941175
} else {

0 commit comments

Comments
 (0)