Skip to content

Commit a0b3a0b

Browse files
committed
Infer and validate PodSets for known GVKs
1 parent f35d93f commit a0b3a0b

File tree

4 files changed

+297
-4
lines changed

4 files changed

+297
-4
lines changed

internal/webhook/appwrapper_fixtures_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,18 @@ func pod(milliCPU int64) workloadv1beta2.AppWrapperComponent {
8686
}
8787
}
8888

89+
func podForInference(milliCPU int64) workloadv1beta2.AppWrapperComponent {
90+
yamlString := fmt.Sprintf(podYAML,
91+
randName("pod"),
92+
resource.NewMilliQuantity(milliCPU, resource.DecimalSI))
93+
94+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
95+
Expect(err).NotTo(HaveOccurred())
96+
return workloadv1beta2.AppWrapperComponent{
97+
Template: runtime.RawExtension{Raw: jsonBytes},
98+
}
99+
}
100+
89101
const namespacedPodYAML = `
90102
apiVersion: v1
91103
kind: Pod
@@ -179,6 +191,19 @@ func deployment(replicaCount int, milliCPU int64) workloadv1beta2.AppWrapperComp
179191
}
180192
}
181193

194+
func deploymentForInference(replicaCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent {
195+
yamlString := fmt.Sprintf(deploymentYAML,
196+
randName("deployment"),
197+
replicaCount,
198+
resource.NewMilliQuantity(milliCPU, resource.DecimalSI))
199+
200+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
201+
Expect(err).NotTo(HaveOccurred())
202+
return workloadv1beta2.AppWrapperComponent{
203+
Template: runtime.RawExtension{Raw: jsonBytes},
204+
}
205+
}
206+
182207
const rayClusterYAML = `
183208
apiVersion: ray.io/v1
184209
kind: RayCluster
@@ -426,3 +451,87 @@ func jobSet(replicasWorker int, milliCPUWorker int64) workloadv1beta2.AppWrapper
426451
Template: runtime.RawExtension{Raw: jsonBytes},
427452
}
428453
}
454+
455+
const jobYAML = `
456+
apiVersion: batch/v1
457+
kind: Job
458+
metadata:
459+
name: %v
460+
spec:
461+
parallelism: %v
462+
completions: %v
463+
template:
464+
spec:
465+
restartPolicy: Never
466+
containers:
467+
- name: busybox
468+
image: quay.io/project-codeflare/busybox:1.36
469+
command: ["sh", "-c", "sleep 30"]
470+
resources:
471+
requests:
472+
cpu: %v`
473+
474+
func jobForInference(parallelism int, completions int, milliCPU int64) workloadv1beta2.AppWrapperComponent {
475+
yamlString := fmt.Sprintf(jobYAML,
476+
randName("job"),
477+
parallelism,
478+
completions,
479+
resource.NewMilliQuantity(milliCPU, resource.DecimalSI))
480+
481+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
482+
Expect(err).NotTo(HaveOccurred())
483+
return workloadv1beta2.AppWrapperComponent{
484+
Template: runtime.RawExtension{Raw: jsonBytes},
485+
}
486+
}
487+
488+
const pytorchJobYAML = `
489+
apiVersion: "kubeflow.org/v1"
490+
kind: PyTorchJob
491+
metadata:
492+
name: %v
493+
spec:
494+
pytorchReplicaSpecs:
495+
Master:
496+
restartPolicy: OnFailure
497+
template:
498+
spec:
499+
containers:
500+
- name: pytorch
501+
image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1
502+
command:
503+
- "python3"
504+
- "/opt/pytorch-mnist/mnist.py"
505+
- "--epochs=1"
506+
resources:
507+
requests:
508+
cpu: %v
509+
Worker:
510+
replicas: %v
511+
restartPolicy: OnFailure
512+
template:
513+
spec:
514+
containers:
515+
- name: pytorch
516+
image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1
517+
command:
518+
- "python3"
519+
- "/opt/pytorch-mnist/mnist.py"
520+
- "--epochs=1"
521+
resources:
522+
requests:
523+
cpu: %v`
524+
525+
func pytorchJobForInference(masterMilliCPU int64, workerReplicas int, workerMilliCPU int64) workloadv1beta2.AppWrapperComponent {
526+
yamlString := fmt.Sprintf(pytorchJobYAML,
527+
randName("pytorch-job"),
528+
resource.NewMilliQuantity(masterMilliCPU, resource.DecimalSI),
529+
workerReplicas,
530+
resource.NewMilliQuantity(workerMilliCPU, resource.DecimalSI))
531+
532+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
533+
Expect(err).NotTo(HaveOccurred())
534+
return workloadv1beta2.AppWrapperComponent{
535+
Template: runtime.RawExtension{Raw: jsonBytes},
536+
}
537+
}

internal/webhook/appwrapper_webhook.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ func (w *AppWrapperWebhook) Default(ctx context.Context, obj runtime.Object) err
6262
if w.Config.EnableKueueIntegrations {
6363
jobframework.ApplyDefaultForSuspend((*wlc.AppWrapper)(aw), w.Config.ManageJobsWithoutQueueName)
6464
}
65+
if err := inferPodSets(ctx, aw); err != nil {
66+
log.FromContext(ctx).Info("Error raised during podSet inference", "job", aw)
67+
return err
68+
}
6569
return nil
6670
}
6771

@@ -98,6 +102,30 @@ func (w *AppWrapperWebhook) ValidateDelete(context.Context, runtime.Object) (adm
98102
return nil, nil
99103
}
100104

105+
// inferPodSets infers the AppWrapper's PodSets
106+
func inferPodSets(_ context.Context, aw *workloadv1beta2.AppWrapper) error {
107+
components := aw.Spec.Components
108+
componentsPath := field.NewPath("spec").Child("components")
109+
for idx, component := range components {
110+
compPath := componentsPath.Index(idx)
111+
112+
// Automatically create elided PodSets for known GVKs
113+
if len(component.PodSets) == 0 {
114+
unstruct := &unstructured.Unstructured{}
115+
_, _, err := unstructured.UnstructuredJSONScheme.Decode(component.Template.Raw, nil, unstruct)
116+
if err != nil {
117+
return field.Invalid(compPath.Child("template"), component.Template, "failed to decode as JSON")
118+
}
119+
podSets, err := utils.InferPodSets(unstruct)
120+
if err != nil {
121+
return err
122+
}
123+
components[idx].PodSets = podSets
124+
}
125+
}
126+
return nil
127+
}
128+
101129
// rbacs required to enable SubjectAccessReview
102130
//+kubebuilder:rbac:groups=authorization.k8s.io,resources=subjectaccessreviews,verbs=create
103131
//+kubebuilder:rbac:groups=apiextensions.k8s.io,resources=customresourcedefinitions,verbs=list
@@ -182,9 +210,15 @@ func (w *AppWrapperWebhook) validateAppWrapperCreate(ctx context.Context, aw *wo
182210
}
183211
podSpecCount += 1
184212
}
213+
214+
// 5. Validate PodSets for known GVKs
215+
if err := utils.ValidatePodSets(unstruct, component.PodSets); err != nil {
216+
allErrors = append(allErrors, field.Invalid(podSetsPath, component.PodSets, err.Error()))
217+
}
218+
185219
}
186220

187-
// 5. Enforce Kueue limitation that 0 < podSpecCount <= 8
221+
// 6. Enforce Kueue limitation that 0 < podSpecCount <= 8
188222
if podSpecCount == 0 {
189223
allErrors = append(allErrors, field.Invalid(componentsPath, components, "components contains no podspecs"))
190224
}

internal/webhook/appwrapper_webhook_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,15 @@ var _ = Describe("AppWrapper Webhook Tests", func() {
190190
Expect(aw.Spec.Suspend).Should(BeTrue())
191191
Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
192192
})
193+
194+
It("PodSets are inferred for known GVKs", func() {
195+
aw := toAppWrapper(pod(100), deploymentForInference(1, 100), podForInference(100),
196+
jobForInference(2, 4, 100), jobForInference(8, 4, 100), pytorchJobForInference(100, 4, 100))
197+
198+
Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets for deployments and pods should be inferred")
199+
Expect(aw.Spec.Suspend).Should(BeTrue())
200+
Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
201+
})
193202
})
194203

195204
})

pkg/utils/utils.go

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,14 @@ import (
2424
v1 "k8s.io/api/core/v1"
2525
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
2626
"k8s.io/apimachinery/pkg/runtime"
27+
"k8s.io/apimachinery/pkg/runtime/schema"
28+
"k8s.io/utils/ptr"
2729

2830
workloadv1beta2 "github.com/project-codeflare/appwrapper/api/v1beta2"
2931
)
3032

33+
const templateString = "template"
34+
3135
// GetPodTemplateSpec extracts a Kueue-compatible PodTemplateSpec at the given path within obj
3236
func GetPodTemplateSpec(obj *unstructured.Unstructured, path string) (*v1.PodTemplateSpec, error) {
3337
candidatePTS, err := GetRawTemplate(obj.UnstructuredContent(), path)
@@ -92,11 +96,11 @@ func GetRawTemplate(obj map[string]interface{}, path string) (map[string]interfa
9296

9397
// get the value found at the given path or an error if the path is invalid
9498
func getValueAtPath(obj map[string]interface{}, path string) (interface{}, error) {
95-
if !strings.HasPrefix(path, "template") {
99+
processed := templateString
100+
if !strings.HasPrefix(path, processed) {
96101
return nil, fmt.Errorf("first element of the path must be 'template'")
97102
}
98-
remaining := strings.TrimPrefix(path, "template")
99-
processed := "template"
103+
remaining := strings.TrimPrefix(path, processed)
100104
var cursor interface{} = obj
101105

102106
for remaining != "" {
@@ -167,3 +171,140 @@ func ExpectedPodCount(aw *workloadv1beta2.AppWrapper) int32 {
167171
}
168172
return expected
169173
}
174+
175+
// InferReplicas parses the value at the given path within obj as an int or return 1 or error
176+
func InferReplicas(obj map[string]interface{}, path string) (int32, error) {
177+
if path == "" {
178+
// no path specified, default to one replica
179+
return 1, nil
180+
}
181+
182+
// check obj is well formed
183+
index := strings.LastIndex(path, ".")
184+
if index >= 0 {
185+
var err error
186+
obj, err = GetRawTemplate(obj, path[:index])
187+
if err != nil {
188+
return 0, err
189+
}
190+
}
191+
192+
// check type and value
193+
switch v := obj[path[index+1:]].(type) {
194+
case nil:
195+
return 1, nil // default to 1
196+
case int:
197+
return int32(v), nil
198+
case int32:
199+
return v, nil
200+
case int64:
201+
return int32(v), nil
202+
default:
203+
return 0, fmt.Errorf("at path position '%v' non-int value %v", path, v)
204+
}
205+
}
206+
207+
// where to find a replica count and a PodTemplateSpec in a resource
208+
type resourceTemplate struct {
209+
path string // path to pod template spec
210+
replicas string // path to replica count
211+
}
212+
213+
// map from known GVKs to resource templates
214+
var templatesForGVK = map[schema.GroupVersionKind][]resourceTemplate{
215+
{Group: "", Version: "v1", Kind: "Pod"}: {{path: "template"}},
216+
{Group: "apps", Version: "v1", Kind: "Deployment"}: {{path: "template.spec.template", replicas: "template.spec.replicas"}},
217+
{Group: "apps", Version: "v1", Kind: "StatefulSet"}: {{path: "template.spec.template", replicas: "template.spec.replicas"}},
218+
}
219+
220+
// InferPodSets infers PodSets for known GVKs
221+
func InferPodSets(obj *unstructured.Unstructured) ([]workloadv1beta2.AppWrapperPodSet, error) {
222+
gvk := obj.GroupVersionKind()
223+
podSets := []workloadv1beta2.AppWrapperPodSet{}
224+
225+
switch gvk {
226+
case schema.GroupVersionKind{Group: "batch", Version: "v1", Kind: "Job"}:
227+
var replicas int32 = 1
228+
if parallelism, err := GetReplicas(obj, "template.spec.parallelism"); err == nil {
229+
replicas = parallelism
230+
}
231+
if completions, err := GetReplicas(obj, "template.spec.completions"); err == nil && completions < replicas {
232+
replicas = completions
233+
}
234+
podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: "template.spec.template"})
235+
236+
case schema.GroupVersionKind{Group: "kubeflow.org", Version: "v1", Kind: "PyTorchJob"}:
237+
for _, replicaType := range []string{"Master", "Worker"} {
238+
prefix := "template.spec.pytorchReplicaSpecs." + replicaType + "."
239+
// validate path to replica template
240+
if _, err := getValueAtPath(obj.UnstructuredContent(), prefix+"template"); err == nil {
241+
// infer replica count
242+
replicas, err := InferReplicas(obj.UnstructuredContent(), prefix+"replicas")
243+
if err != nil {
244+
return nil, err
245+
}
246+
podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: prefix + "template"})
247+
}
248+
}
249+
250+
default:
251+
for _, template := range templatesForGVK[gvk] {
252+
// validate path to template
253+
if _, err := getValueAtPath(obj.UnstructuredContent(), template.path); err == nil {
254+
replicas, err := InferReplicas(obj.UnstructuredContent(), template.replicas)
255+
// infer replica count
256+
if err != nil {
257+
return nil, err
258+
}
259+
podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: template.path})
260+
}
261+
}
262+
}
263+
264+
return podSets, nil
265+
}
266+
267+
// ValidatePodSets compares declared and inferred PodSets for known GVKs
268+
func ValidatePodSets(obj *unstructured.Unstructured, podSets []workloadv1beta2.AppWrapperPodSet) error {
269+
declared := map[string]workloadv1beta2.AppWrapperPodSet{}
270+
271+
// construct a map with declared PodSets and find duplicates
272+
for _, p := range podSets {
273+
if _, ok := declared[p.Path]; ok {
274+
return fmt.Errorf("duplicate PodSets with path '%v'", p.Path)
275+
}
276+
declared[p.Path] = p
277+
}
278+
279+
// infer PodSets
280+
inferred, err := InferPodSets(obj)
281+
if err != nil {
282+
return err
283+
}
284+
285+
// nothing inferred, nothing to validate
286+
if len(inferred) == 0 {
287+
return nil
288+
}
289+
290+
// compare PodSet counts
291+
if len(inferred) != len(declared) {
292+
return fmt.Errorf("PodSet count %v differs from expected count %v", len(declared), len(inferred))
293+
}
294+
295+
// match inferred PodSets to declared PodSets
296+
for _, ips := range inferred {
297+
dps, ok := declared[ips.Path]
298+
if !ok {
299+
return fmt.Errorf("PodSet with path '%v' is missing", ips.Path)
300+
}
301+
302+
ipr := ptr.Deref(ips.Replicas, 1)
303+
dpr := ptr.Deref(dps.Replicas, 1)
304+
if ipr != dpr {
305+
return fmt.Errorf("replica count %v differs from expected count %v for PodSet at path position '%v'", dpr, ipr, ips.Path)
306+
}
307+
}
308+
309+
return nil
310+
}

0 commit comments

Comments
 (0)