@@ -24,10 +24,14 @@ import (
2424 v1 "k8s.io/api/core/v1"
2525 "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
2626 "k8s.io/apimachinery/pkg/runtime"
27+ "k8s.io/apimachinery/pkg/runtime/schema"
28+ "k8s.io/utils/ptr"
2729
2830 workloadv1beta2 "github.com/project-codeflare/appwrapper/api/v1beta2"
2931)
3032
// templateString is the prefix every path handled by getValueAtPath must
// start with.
const (
	templateString = "template"
)
34+
3135// GetPodTemplateSpec extracts a Kueue-compatible PodTemplateSpec at the given path within obj
3236func GetPodTemplateSpec (obj * unstructured.Unstructured , path string ) (* v1.PodTemplateSpec , error ) {
3337 candidatePTS , err := GetRawTemplate (obj .UnstructuredContent (), path )
@@ -92,11 +96,11 @@ func GetRawTemplate(obj map[string]interface{}, path string) (map[string]interfa
9296
9397// get the value found at the given path or an error if the path is invalid
9498func getValueAtPath (obj map [string ]interface {}, path string ) (interface {}, error ) {
95- if ! strings .HasPrefix (path , "template" ) {
99+ processed := templateString
100+ if ! strings .HasPrefix (path , processed ) {
96101 return nil , fmt .Errorf ("first element of the path must be 'template'" )
97102 }
98- remaining := strings .TrimPrefix (path , "template" )
99- processed := "template"
103+ remaining := strings .TrimPrefix (path , processed )
100104 var cursor interface {} = obj
101105
102106 for remaining != "" {
@@ -167,3 +171,140 @@ func ExpectedPodCount(aw *workloadv1beta2.AppWrapper) int32 {
167171 }
168172 return expected
169173}
174+
175+ // InferReplicas parses the value at the given path within obj as an int or return 1 or error
176+ func InferReplicas (obj map [string ]interface {}, path string ) (int32 , error ) {
177+ if path == "" {
178+ // no path specified, default to one replica
179+ return 1 , nil
180+ }
181+
182+ // check obj is well formed
183+ index := strings .LastIndex (path , "." )
184+ if index >= 0 {
185+ var err error
186+ obj , err = GetRawTemplate (obj , path [:index ])
187+ if err != nil {
188+ return 0 , err
189+ }
190+ }
191+
192+ // check type and value
193+ switch v := obj [path [index + 1 :]].(type ) {
194+ case nil :
195+ return 1 , nil // default to 1
196+ case int :
197+ return int32 (v ), nil
198+ case int32 :
199+ return v , nil
200+ case int64 :
201+ return int32 (v ), nil
202+ default :
203+ return 0 , fmt .Errorf ("at path position '%v' non-int value %v" , path , v )
204+ }
205+ }
206+
// resourceTemplate records where a PodTemplateSpec and its replica count are
// located inside a resource. An empty replicas path means no replica count
// path is given for that template.
type resourceTemplate struct {
	// path locates the PodTemplateSpec; replicas locates the replica count
	path, replicas string
}
212+
213+ // map from known GVKs to resource templates
214+ var templatesForGVK = map [schema.GroupVersionKind ][]resourceTemplate {
215+ {Group : "" , Version : "v1" , Kind : "Pod" }: {{path : "template" }},
216+ {Group : "apps" , Version : "v1" , Kind : "Deployment" }: {{path : "template.spec.template" , replicas : "template.spec.replicas" }},
217+ {Group : "apps" , Version : "v1" , Kind : "StatefulSet" }: {{path : "template.spec.template" , replicas : "template.spec.replicas" }},
218+ }
219+
220+ // InferPodSets infers PodSets for known GVKs
221+ func InferPodSets (obj * unstructured.Unstructured ) ([]workloadv1beta2.AppWrapperPodSet , error ) {
222+ gvk := obj .GroupVersionKind ()
223+ podSets := []workloadv1beta2.AppWrapperPodSet {}
224+
225+ switch gvk {
226+ case schema.GroupVersionKind {Group : "batch" , Version : "v1" , Kind : "Job" }:
227+ var replicas int32 = 1
228+ if parallelism , err := GetReplicas (obj , "template.spec.parallelism" ); err == nil {
229+ replicas = parallelism
230+ }
231+ if completions , err := GetReplicas (obj , "template.spec.completions" ); err == nil && completions < replicas {
232+ replicas = completions
233+ }
234+ podSets = append (podSets , workloadv1beta2.AppWrapperPodSet {Replicas : ptr .To (replicas ), Path : "template.spec.template" })
235+
236+ case schema.GroupVersionKind {Group : "kubeflow.org" , Version : "v1" , Kind : "PyTorchJob" }:
237+ for _ , replicaType := range []string {"Master" , "Worker" } {
238+ prefix := "template.spec.pytorchReplicaSpecs." + replicaType + "."
239+ // validate path to replica template
240+ if _ , err := getValueAtPath (obj .UnstructuredContent (), prefix + "template" ); err == nil {
241+ // infer replica count
242+ replicas , err := InferReplicas (obj .UnstructuredContent (), prefix + "replicas" )
243+ if err != nil {
244+ return nil , err
245+ }
246+ podSets = append (podSets , workloadv1beta2.AppWrapperPodSet {Replicas : ptr .To (replicas ), Path : prefix + "template" })
247+ }
248+ }
249+
250+ default :
251+ for _ , template := range templatesForGVK [gvk ] {
252+ // validate path to template
253+ if _ , err := getValueAtPath (obj .UnstructuredContent (), template .path ); err == nil {
254+ replicas , err := InferReplicas (obj .UnstructuredContent (), template .replicas )
255+ // infer replica count
256+ if err != nil {
257+ return nil , err
258+ }
259+ podSets = append (podSets , workloadv1beta2.AppWrapperPodSet {Replicas : ptr .To (replicas ), Path : template .path })
260+ }
261+ }
262+ }
263+
264+ return podSets , nil
265+ }
266+
267+ // ValidatePodSets compares declared and inferred PodSets for known GVKs
268+ func ValidatePodSets (obj * unstructured.Unstructured , podSets []workloadv1beta2.AppWrapperPodSet ) error {
269+ declared := map [string ]workloadv1beta2.AppWrapperPodSet {}
270+
271+ // construct a map with declared PodSets and find duplicates
272+ for _ , p := range podSets {
273+ if _ , ok := declared [p .Path ]; ok {
274+ return fmt .Errorf ("duplicate PodSets with path '%v'" , p .Path )
275+ }
276+ declared [p .Path ] = p
277+ }
278+
279+ // infer PodSets
280+ inferred , err := InferPodSets (obj )
281+ if err != nil {
282+ return err
283+ }
284+
285+ // nothing inferred, nothing to validate
286+ if len (inferred ) == 0 {
287+ return nil
288+ }
289+
290+ // compare PodSet counts
291+ if len (inferred ) != len (declared ) {
292+ return fmt .Errorf ("PodSet count %v differs from expected count %v" , len (declared ), len (inferred ))
293+ }
294+
295+ // match inferred PodSets to declared PodSets
296+ for _ , ips := range inferred {
297+ dps , ok := declared [ips .Path ]
298+ if ! ok {
299+ return fmt .Errorf ("PodSet with path '%v' is missing" , ips .Path )
300+ }
301+
302+ ipr := ptr .Deref (ips .Replicas , 1 )
303+ dpr := ptr .Deref (dps .Replicas , 1 )
304+ if ipr != dpr {
305+ return fmt .Errorf ("replica count %v differs from expected count %v for PodSet at path position '%v'" , dpr , ipr , ips .Path )
306+ }
307+ }
308+
309+ return nil
310+ }
0 commit comments