diff --git a/internal/webhook/appwrapper_fixtures_test.go b/internal/webhook/appwrapper_fixtures_test.go index c1942f8..1d5e370 100644 --- a/internal/webhook/appwrapper_fixtures_test.go +++ b/internal/webhook/appwrapper_fixtures_test.go @@ -86,6 +86,18 @@ func pod(milliCPU int64) workloadv1beta2.AppWrapperComponent { } } +func podForInference(milliCPU int64) workloadv1beta2.AppWrapperComponent { + yamlString := fmt.Sprintf(podYAML, + randName("pod"), + resource.NewMilliQuantity(milliCPU, resource.DecimalSI)) + + jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString)) + Expect(err).NotTo(HaveOccurred()) + return workloadv1beta2.AppWrapperComponent{ + Template: runtime.RawExtension{Raw: jsonBytes}, + } +} + const namespacedPodYAML = ` apiVersion: v1 kind: Pod @@ -179,6 +191,19 @@ func deployment(replicaCount int, milliCPU int64) workloadv1beta2.AppWrapperComp } } +func deploymentForInference(replicaCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent { + yamlString := fmt.Sprintf(deploymentYAML, + randName("deployment"), + replicaCount, + resource.NewMilliQuantity(milliCPU, resource.DecimalSI)) + + jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString)) + Expect(err).NotTo(HaveOccurred()) + return workloadv1beta2.AppWrapperComponent{ + Template: runtime.RawExtension{Raw: jsonBytes}, + } +} + const rayClusterYAML = ` apiVersion: ray.io/v1 kind: RayCluster @@ -371,6 +396,20 @@ func rayCluster(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperCompo } } +func rayClusterForInference(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent { + workerCPU := resource.NewMilliQuantity(milliCPU, resource.DecimalSI) + yamlString := fmt.Sprintf(rayClusterYAML, + randName("raycluster"), + workerCount, workerCount, workerCount, + workerCPU) + + jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString)) + Expect(err).NotTo(HaveOccurred()) + return workloadv1beta2.AppWrapperComponent{ + Template: runtime.RawExtension{Raw: jsonBytes}, + } +} + const jobSetYAML = ` apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet @@ -426,3 +465,128 @@ func jobSet(replicasWorker int, milliCPUWorker int64) workloadv1beta2.AppWrapper Template: runtime.RawExtension{Raw: jsonBytes}, } } + +const jobYAML = ` +apiVersion: batch/v1 +kind: Job +metadata: + name: %v +spec: + parallelism: %v + completions: %v + template: + spec: + restartPolicy: Never + containers: + - name: busybox + image: quay.io/project-codeflare/busybox:1.36 + command: ["sh", "-c", "sleep 30"] + resources: + requests: + cpu: %v` + +func jobForInference(parallelism int, completions int, milliCPU int64) workloadv1beta2.AppWrapperComponent { + yamlString := fmt.Sprintf(jobYAML, + randName("job"), + parallelism, + completions, + resource.NewMilliQuantity(milliCPU, resource.DecimalSI)) + + jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString)) + Expect(err).NotTo(HaveOccurred()) + return workloadv1beta2.AppWrapperComponent{ + Template: runtime.RawExtension{Raw: jsonBytes}, + } +} + +const pytorchJobYAML = ` +apiVersion: "kubeflow.org/v1" +kind: PyTorchJob +metadata: + name: %v +spec: + pytorchReplicaSpecs: + Master: + restartPolicy: OnFailure + template: + spec: + containers: + - name: pytorch + image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1 + command: + - "python3" + - "/opt/pytorch-mnist/mnist.py" + - "--epochs=1" + resources: + requests: + cpu: %v + Worker: + replicas: %v + restartPolicy: OnFailure + template: + spec: + containers: + - name: pytorch + image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1 + command: + - "python3" + - "/opt/pytorch-mnist/mnist.py" + - "--epochs=1" + resources: + requests: + cpu: %v` + +func pytorchJobForInference(masterMilliCPU int64, workerReplicas int, workerMilliCPU int64) workloadv1beta2.AppWrapperComponent { + yamlString := fmt.Sprintf(pytorchJobYAML, + randName("pytorch-job"), + resource.NewMilliQuantity(masterMilliCPU, resource.DecimalSI), + workerReplicas, + resource.NewMilliQuantity(workerMilliCPU, resource.DecimalSI)) + + jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString)) + Expect(err).NotTo(HaveOccurred()) + return workloadv1beta2.AppWrapperComponent{ + Template: runtime.RawExtension{Raw: jsonBytes}, + } +} + +const rayJobYAML = ` +apiVersion: ray.io/v1 +kind: RayJob +metadata: + name: %v +spec: + rayClusterSpec: + headGroupSpec: + template: + spec: + containers: + - name: ray-head + image: rayproject/ray:2.9.0 + resources: + requests: + cpu: 1 + workerGroupSpecs: + - replicas: %v + template: + spec: + containers: + - name: ray-worker + image: rayproject/ray:2.9.0 + resources: + requests: + cpu: %v +` + +func rayJobForInference(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent { + yamlString := fmt.Sprintf(rayJobYAML, + randName("rayjob"), + workerCount, + resource.NewMilliQuantity(milliCPU, resource.DecimalSI)) + + jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString)) + Expect(err).NotTo(HaveOccurred()) + return workloadv1beta2.AppWrapperComponent{ + Template: runtime.RawExtension{Raw: jsonBytes}, + } +} diff --git a/internal/webhook/appwrapper_webhook.go b/internal/webhook/appwrapper_webhook.go index 64e5bb0..402c029 100644 --- a/internal/webhook/appwrapper_webhook.go +++ b/internal/webhook/appwrapper_webhook.go @@ -62,6 +62,10 @@ func (w *AppWrapperWebhook) Default(ctx context.Context, obj runtime.Object) err if w.Config.EnableKueueIntegrations { jobframework.ApplyDefaultForSuspend((*wlc.AppWrapper)(aw), w.Config.ManageJobsWithoutQueueName) } + if err := inferPodSets(ctx, aw); err != nil { + log.FromContext(ctx).Info("Error raised during podSet inference", "job", aw) + return err + } return nil } @@ -98,6 +102,30 @@ func (w *AppWrapperWebhook) ValidateDelete(context.Context, runtime.Object) (adm return nil, nil } +// inferPodSets infers the AppWrapper's PodSets +func inferPodSets(_ context.Context, aw *workloadv1beta2.AppWrapper) error { + components := aw.Spec.Components + componentsPath := field.NewPath("spec").Child("components") + for idx, component := range components { + compPath := componentsPath.Index(idx) + + // Automatically create elided PodSets for known GVKs + if len(component.PodSets) == 0 { + unstruct := &unstructured.Unstructured{} + _, _, err := unstructured.UnstructuredJSONScheme.Decode(component.Template.Raw, nil, unstruct) + if err != nil { + return field.Invalid(compPath.Child("template"), component.Template, "failed to decode as JSON") + } + podSets, err := utils.InferPodSets(unstruct) + if err != nil { + return err + } + components[idx].PodSets = podSets + } + } + return nil +} + // rbacs required to enable SubjectAccessReview //+kubebuilder:rbac:groups=authorization.k8s.io,resources=subjectaccessreviews,verbs=create //+kubebuilder:rbac:groups=apiextensions.k8s.io,resources=customresourcedefinitions,verbs=list @@ -182,9 +210,15 @@ func (w *AppWrapperWebhook) validateAppWrapperCreate(ctx context.Context, aw *wo } podSpecCount += 1 } + + // 5. Validate PodSets for known GVKs + if err := utils.ValidatePodSets(unstruct, component.PodSets); err != nil { + allErrors = append(allErrors, field.Invalid(podSetsPath, component.PodSets, err.Error())) + } + } - // 5. Enforce Kueue limitation that 0 < podSpecCount <= 8 + // 6. Enforce Kueue limitation that 0 < podSpecCount <= 8 if podSpecCount == 0 { allErrors = append(allErrors, field.Invalid(componentsPath, components, "components contains no podspecs")) } diff --git a/internal/webhook/appwrapper_webhook_test.go b/internal/webhook/appwrapper_webhook_test.go index 8894fb0..abfb109 100644 --- a/internal/webhook/appwrapper_webhook_test.go +++ b/internal/webhook/appwrapper_webhook_test.go @@ -190,6 +190,25 @@ var _ = Describe("AppWrapper Webhook Tests", func() { Expect(aw.Spec.Suspend).Should(BeTrue()) Expect(k8sClient.Delete(ctx, aw)).To(Succeed()) }) + + Context("PodSets are inferred for known GVKs", func() { + It("PodSets are inferred for common kinds", func() { + aw := toAppWrapper(pod(100), deploymentForInference(1, 100), podForInference(100), + jobForInference(2, 4, 100), jobForInference(8, 4, 100)) + + Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets should be inferred") + Expect(aw.Spec.Suspend).Should(BeTrue()) + Expect(k8sClient.Delete(ctx, aw)).To(Succeed()) + }) + + It("PodSets are inferred for PyTorchJobs, RayClusters, and RayJobs", func() { + aw := toAppWrapper(pytorchJobForInference(100, 4, 100), rayClusterForInference(7, 100), rayJobForInference(7, 100)) + + Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets should be inferred") + Expect(aw.Spec.Suspend).Should(BeTrue()) + Expect(k8sClient.Delete(ctx, aw)).To(Succeed()) + }) + }) }) }) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index d4779de..2fc2361 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -24,10 +24,14 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/utils/ptr" workloadv1beta2 "github.com/project-codeflare/appwrapper/api/v1beta2" ) +const templateString = "template" + // GetPodTemplateSpec extracts a Kueue-compatible PodTemplateSpec at the given path within obj func GetPodTemplateSpec(obj *unstructured.Unstructured, path string) (*v1.PodTemplateSpec, error) { candidatePTS, err := GetRawTemplate(obj.UnstructuredContent(), path) @@ -92,11 +96,11 @@ func GetRawTemplate(obj map[string]interface{}, path string) (map[string]interfa // get the value found at the given path or an error if the path is invalid func getValueAtPath(obj map[string]interface{}, path string) (interface{}, error) { - if !strings.HasPrefix(path, "template") { + processed := templateString + if !strings.HasPrefix(path, processed) { return nil, fmt.Errorf("first element of the path must be 'template'") } - remaining := strings.TrimPrefix(path, "template") - processed := "template" + remaining := strings.TrimPrefix(path, processed) var cursor interface{} = obj for remaining != "" { @@ -167,3 +171,178 @@ func ExpectedPodCount(aw *workloadv1beta2.AppWrapper) int32 { } return expected } + +// inferReplicas parses the value at the given path within obj as an int or return 1 or error +func inferReplicas(obj map[string]interface{}, path string) (int32, error) { + if path == "" { + // no path specified, default to one replica + return 1, nil + } + + // check obj is well formed + index := strings.LastIndex(path, ".") + if index >= 0 { + var err error + obj, err = GetRawTemplate(obj, path[:index]) + if err != nil { + return 0, err + } + } + + // check type and value + switch v := obj[path[index+1:]].(type) { + case nil: + return 1, nil // default to 1 + case int: + return int32(v), nil + case int32: + return v, nil + case int64: + return int32(v), nil + default: + return 0, fmt.Errorf("at path position '%v' non-int value %v", path, v) + } +} + +// where to find a replica count and a PodTemplateSpec in a resource +type resourceTemplate struct { + path string // path to pod template spec + replicas string // path to replica count +} + +// map from known GVKs to resource templates +var templatesForGVK = map[schema.GroupVersionKind][]resourceTemplate{ + {Group: "", Version: "v1", Kind: "Pod"}: {{path: "template"}}, + {Group: "apps", Version: "v1", Kind: "Deployment"}: {{path: "template.spec.template", replicas: "template.spec.replicas"}}, + {Group: "apps", Version: "v1", Kind: "StatefulSet"}: {{path: "template.spec.template", replicas: "template.spec.replicas"}}, +} + +// inferPodSets infers PodSets for RayJobs and RayClusters +func inferRayPodSets(obj *unstructured.Unstructured, clusterSpecPrefix string) ([]workloadv1beta2.AppWrapperPodSet, error) { + podSets := []workloadv1beta2.AppWrapperPodSet{} + + podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(int32(1)), Path: clusterSpecPrefix + "headGroupSpec.template"}) + if workers, err := getValueAtPath(obj.UnstructuredContent(), clusterSpecPrefix+"workerGroupSpecs"); err == nil { + if workers, ok := workers.([]interface{}); ok { + for i := range workers { + workerGroupSpecPrefix := fmt.Sprintf(clusterSpecPrefix+"workerGroupSpecs[%v].", i) + // validate path to replica template + if _, err := getValueAtPath(obj.UnstructuredContent(), workerGroupSpecPrefix+templateString); err == nil { + // infer replica count + replicas, err := inferReplicas(obj.UnstructuredContent(), workerGroupSpecPrefix+"replicas") + if err != nil { + return nil, err + } + podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: workerGroupSpecPrefix + templateString}) + } + } + } + } + return podSets, nil +} + +// InferPodSets infers PodSets for known GVKs +func InferPodSets(obj *unstructured.Unstructured) ([]workloadv1beta2.AppWrapperPodSet, error) { + gvk := obj.GroupVersionKind() + podSets := []workloadv1beta2.AppWrapperPodSet{} + + switch gvk { + case schema.GroupVersionKind{Group: "batch", Version: "v1", Kind: "Job"}: + var replicas int32 = 1 + if parallelism, err := GetReplicas(obj, "template.spec.parallelism"); err == nil { + replicas = parallelism + } + if completions, err := GetReplicas(obj, "template.spec.completions"); err == nil && completions < replicas { + replicas = completions + } + podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: "template.spec.template"}) + + case schema.GroupVersionKind{Group: "kubeflow.org", Version: "v1", Kind: "PyTorchJob"}: + for _, replicaType := range []string{"Master", "Worker"} { + prefix := "template.spec.pytorchReplicaSpecs." + replicaType + "." + // validate path to replica template + if _, err := getValueAtPath(obj.UnstructuredContent(), prefix+templateString); err == nil { + // infer replica count + replicas, err := inferReplicas(obj.UnstructuredContent(), prefix+"replicas") + if err != nil { + return nil, err + } + podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: prefix + templateString}) + } + } + + case schema.GroupVersionKind{Group: "ray.io", Version: "v1", Kind: "RayCluster"}: + rayPodSets, err := inferRayPodSets(obj, "template.spec.") + if err != nil { + return nil, err + } + podSets = append(podSets, rayPodSets...) + + case schema.GroupVersionKind{Group: "ray.io", Version: "v1", Kind: "RayJob"}: + rayPodSets, err := inferRayPodSets(obj, "template.spec.rayClusterSpec.") + if err != nil { + return nil, err + } + podSets = append(podSets, rayPodSets...) + + default: + for _, template := range templatesForGVK[gvk] { + // validate path to template + if _, err := getValueAtPath(obj.UnstructuredContent(), template.path); err == nil { + replicas, err := inferReplicas(obj.UnstructuredContent(), template.replicas) + // infer replica count + if err != nil { + return nil, err + } + podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: template.path}) + } + } + } + + return podSets, nil +} + +// ValidatePodSets compares declared and inferred PodSets for known GVKs +func ValidatePodSets(obj *unstructured.Unstructured, podSets []workloadv1beta2.AppWrapperPodSet) error { + declared := map[string]workloadv1beta2.AppWrapperPodSet{} + + // construct a map with declared PodSets and find duplicates + for _, p := range podSets { + if _, ok := declared[p.Path]; ok { + return fmt.Errorf("duplicate PodSets with path '%v'", p.Path) + } + declared[p.Path] = p + } + + // infer PodSets + inferred, err := InferPodSets(obj) + if err != nil { + return err + } + + // nothing inferred, nothing to validate + if len(inferred) == 0 { + return nil + } + + // compare PodSet counts + if len(inferred) != len(declared) { + return fmt.Errorf("PodSet count %v differs from expected count %v", len(declared), len(inferred)) + } + + // match inferred PodSets to declared PodSets + for _, ips := range inferred { + dps, ok := declared[ips.Path] + if !ok { + return fmt.Errorf("PodSet with path '%v' is missing", ips.Path) + } + + ipr := ptr.Deref(ips.Replicas, 1) + dpr := ptr.Deref(dps.Replicas, 1) + if ipr != dpr { + return fmt.Errorf("replica count %v differs from expected count %v for PodSet at path position '%v'", dpr, ipr, ips.Path) + } + } + + return nil +}