Skip to content

Commit b87fa22

Browse files
authored
Infer pod sets for known GVKs (#108)
1 parent af4ef6d commit b87fa22

File tree

4 files changed

+400
-4
lines changed

4 files changed

+400
-4
lines changed

internal/webhook/appwrapper_fixtures_test.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,18 @@ func pod(milliCPU int64) workloadv1beta2.AppWrapperComponent {
8686
}
8787
}
8888

89+
func podForInference(milliCPU int64) workloadv1beta2.AppWrapperComponent {
90+
yamlString := fmt.Sprintf(podYAML,
91+
randName("pod"),
92+
resource.NewMilliQuantity(milliCPU, resource.DecimalSI))
93+
94+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
95+
Expect(err).NotTo(HaveOccurred())
96+
return workloadv1beta2.AppWrapperComponent{
97+
Template: runtime.RawExtension{Raw: jsonBytes},
98+
}
99+
}
100+
89101
const namespacedPodYAML = `
90102
apiVersion: v1
91103
kind: Pod
@@ -179,6 +191,19 @@ func deployment(replicaCount int, milliCPU int64) workloadv1beta2.AppWrapperComp
179191
}
180192
}
181193

194+
func deploymentForInference(replicaCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent {
195+
yamlString := fmt.Sprintf(deploymentYAML,
196+
randName("deployment"),
197+
replicaCount,
198+
resource.NewMilliQuantity(milliCPU, resource.DecimalSI))
199+
200+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
201+
Expect(err).NotTo(HaveOccurred())
202+
return workloadv1beta2.AppWrapperComponent{
203+
Template: runtime.RawExtension{Raw: jsonBytes},
204+
}
205+
}
206+
182207
const rayClusterYAML = `
183208
apiVersion: ray.io/v1
184209
kind: RayCluster
@@ -371,6 +396,20 @@ func rayCluster(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperCompo
371396
}
372397
}
373398

399+
func rayClusterForInference(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent {
400+
workerCPU := resource.NewMilliQuantity(milliCPU, resource.DecimalSI)
401+
yamlString := fmt.Sprintf(rayClusterYAML,
402+
randName("raycluster"),
403+
workerCount, workerCount, workerCount,
404+
workerCPU)
405+
406+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
407+
Expect(err).NotTo(HaveOccurred())
408+
return workloadv1beta2.AppWrapperComponent{
409+
Template: runtime.RawExtension{Raw: jsonBytes},
410+
}
411+
}
412+
374413
const jobSetYAML = `
375414
apiVersion: jobset.x-k8s.io/v1alpha2
376415
kind: JobSet
@@ -426,3 +465,128 @@ func jobSet(replicasWorker int, milliCPUWorker int64) workloadv1beta2.AppWrapper
426465
Template: runtime.RawExtension{Raw: jsonBytes},
427466
}
428467
}
468+
469+
// jobYAML is a fmt.Sprintf template for a batch/v1 Job running a single
// busybox container. Its format verbs are, in order: the Job name,
// spec.parallelism, spec.completions, and the container's CPU request.
const jobYAML = `
apiVersion: batch/v1
kind: Job
metadata:
  name: %v
spec:
  parallelism: %v
  completions: %v
  template:
    spec:
      restartPolicy: Never
      containers:
      - name: busybox
        image: quay.io/project-codeflare/busybox:1.36
        command: ["sh", "-c", "sleep 30"]
        resources:
          requests:
            cpu: %v`
487+
488+
func jobForInference(parallelism int, completions int, milliCPU int64) workloadv1beta2.AppWrapperComponent {
489+
yamlString := fmt.Sprintf(jobYAML,
490+
randName("job"),
491+
parallelism,
492+
completions,
493+
resource.NewMilliQuantity(milliCPU, resource.DecimalSI))
494+
495+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
496+
Expect(err).NotTo(HaveOccurred())
497+
return workloadv1beta2.AppWrapperComponent{
498+
Template: runtime.RawExtension{Raw: jsonBytes},
499+
}
500+
}
501+
502+
// pytorchJobYAML is a fmt.Sprintf template for a kubeflow.org/v1 PyTorchJob
// with a Master and a Worker replica spec. Its format verbs are, in order:
// the job name, the Master container's CPU request, the Worker replica count,
// and the Worker container's CPU request.
const pytorchJobYAML = `
apiVersion: "kubeflow.org/v1"
kind: PyTorchJob
metadata:
  name: %v
spec:
  pytorchReplicaSpecs:
    Master:
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: pytorch
            image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1
            command:
            - "python3"
            - "/opt/pytorch-mnist/mnist.py"
            - "--epochs=1"
            resources:
              requests:
                cpu: %v
    Worker:
      replicas: %v
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: pytorch
            image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1
            command:
            - "python3"
            - "/opt/pytorch-mnist/mnist.py"
            - "--epochs=1"
            resources:
              requests:
                cpu: %v`
538+
539+
func pytorchJobForInference(masterMilliCPU int64, workerReplicas int, workerMilliCPU int64) workloadv1beta2.AppWrapperComponent {
540+
yamlString := fmt.Sprintf(pytorchJobYAML,
541+
randName("pytorch-job"),
542+
resource.NewMilliQuantity(masterMilliCPU, resource.DecimalSI),
543+
workerReplicas,
544+
resource.NewMilliQuantity(workerMilliCPU, resource.DecimalSI))
545+
546+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
547+
Expect(err).NotTo(HaveOccurred())
548+
return workloadv1beta2.AppWrapperComponent{
549+
Template: runtime.RawExtension{Raw: jsonBytes},
550+
}
551+
}
552+
553+
// rayJobYAML is a fmt.Sprintf template for a ray.io/v1 RayJob whose
// rayClusterSpec has one head group (fixed CPU request of 1) and one worker
// group. Its format verbs are, in order: the RayJob name, the worker group's
// replica count, and the worker container's CPU request.
const rayJobYAML = `
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: %v
spec:
  rayClusterSpec:
    headGroupSpec:
      template:
        spec:
          containers:
          - name: ray-head
            image: rayproject/ray:2.9.0
            resources:
              requests:
                cpu: 1
    workerGroupSpecs:
    - replicas: %v
      template:
        spec:
          containers:
          - name: ray-worker
            image: rayproject/ray:2.9.0
            resources:
              requests:
                cpu: %v
`
580+
581+
func rayJobForInference(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent {
582+
yamlString := fmt.Sprintf(rayJobYAML,
583+
randName("rayjob"),
584+
workerCount,
585+
resource.NewMilliQuantity(milliCPU, resource.DecimalSI))
586+
587+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
588+
Expect(err).NotTo(HaveOccurred())
589+
return workloadv1beta2.AppWrapperComponent{
590+
Template: runtime.RawExtension{Raw: jsonBytes},
591+
}
592+
}

internal/webhook/appwrapper_webhook.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ func (w *AppWrapperWebhook) Default(ctx context.Context, obj runtime.Object) err
6262
if w.Config.EnableKueueIntegrations {
6363
jobframework.ApplyDefaultForSuspend((*wlc.AppWrapper)(aw), w.Config.ManageJobsWithoutQueueName)
6464
}
65+
if err := inferPodSets(ctx, aw); err != nil {
66+
log.FromContext(ctx).Info("Error raised during podSet inference", "job", aw)
67+
return err
68+
}
6569
return nil
6670
}
6771

@@ -98,6 +102,30 @@ func (w *AppWrapperWebhook) ValidateDelete(context.Context, runtime.Object) (adm
98102
return nil, nil
99103
}
100104

105+
// inferPodSets infers the AppWrapper's PodSets
106+
func inferPodSets(_ context.Context, aw *workloadv1beta2.AppWrapper) error {
107+
components := aw.Spec.Components
108+
componentsPath := field.NewPath("spec").Child("components")
109+
for idx, component := range components {
110+
compPath := componentsPath.Index(idx)
111+
112+
// Automatically create elided PodSets for known GVKs
113+
if len(component.PodSets) == 0 {
114+
unstruct := &unstructured.Unstructured{}
115+
_, _, err := unstructured.UnstructuredJSONScheme.Decode(component.Template.Raw, nil, unstruct)
116+
if err != nil {
117+
return field.Invalid(compPath.Child("template"), component.Template, "failed to decode as JSON")
118+
}
119+
podSets, err := utils.InferPodSets(unstruct)
120+
if err != nil {
121+
return err
122+
}
123+
components[idx].PodSets = podSets
124+
}
125+
}
126+
return nil
127+
}
128+
101129
// rbacs required to enable SubjectAccessReview
102130
//+kubebuilder:rbac:groups=authorization.k8s.io,resources=subjectaccessreviews,verbs=create
103131
//+kubebuilder:rbac:groups=apiextensions.k8s.io,resources=customresourcedefinitions,verbs=list
@@ -182,9 +210,15 @@ func (w *AppWrapperWebhook) validateAppWrapperCreate(ctx context.Context, aw *wo
182210
}
183211
podSpecCount += 1
184212
}
213+
214+
// 5. Validate PodSets for known GVKs
215+
if err := utils.ValidatePodSets(unstruct, component.PodSets); err != nil {
216+
allErrors = append(allErrors, field.Invalid(podSetsPath, component.PodSets, err.Error()))
217+
}
218+
185219
}
186220

187-
// 5. Enforce Kueue limitation that 0 < podSpecCount <= 8
221+
// 6. Enforce Kueue limitation that 0 < podSpecCount <= 8
188222
if podSpecCount == 0 {
189223
allErrors = append(allErrors, field.Invalid(componentsPath, components, "components contains no podspecs"))
190224
}

internal/webhook/appwrapper_webhook_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,25 @@ var _ = Describe("AppWrapper Webhook Tests", func() {
190190
Expect(aw.Spec.Suspend).Should(BeTrue())
191191
Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
192192
})
193+
194+
// These specs create AppWrappers whose components (apart from pod(100)) are
// built by the *ForInference fixtures, which leave PodSets unset. After a
// successful Create, every component is expected to carry at least one PodSet.
Context("PodSets are inferred for known GVKs", func() {
	It("PodSets are inferred for common kinds", func() {
		aw := toAppWrapper(pod(100), deploymentForInference(1, 100), podForInference(100),
			jobForInference(2, 4, 100), jobForInference(8, 4, 100))

		Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets should be inferred")
		Expect(aw.Spec.Suspend).Should(BeTrue())
		// Check the inference result directly instead of relying only on Create succeeding.
		for idx, component := range aw.Spec.Components {
			Expect(component.PodSets).NotTo(BeEmpty(), "component %v should have PodSets", idx)
		}
		Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
	})

	It("PodSets are inferred for PyTorchJobs, RayClusters, and RayJobs", func() {
		aw := toAppWrapper(pytorchJobForInference(100, 4, 100), rayClusterForInference(7, 100), rayJobForInference(7, 100))

		Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets should be inferred")
		Expect(aw.Spec.Suspend).Should(BeTrue())
		// Check the inference result directly instead of relying only on Create succeeding.
		for idx, component := range aw.Spec.Components {
			Expect(component.PodSets).NotTo(BeEmpty(), "component %v should have PodSets", idx)
		}
		Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
	})
})
193212
})
194213

195214
})

0 commit comments

Comments
 (0)