@@ -86,6 +86,18 @@ func pod(milliCPU int64) workloadv1beta2.AppWrapperComponent {
8686 }
8787}
8888
89+ func podForInference (milliCPU int64 ) workloadv1beta2.AppWrapperComponent {
90+ yamlString := fmt .Sprintf (podYAML ,
91+ randName ("pod" ),
92+ resource .NewMilliQuantity (milliCPU , resource .DecimalSI ))
93+
94+ jsonBytes , err := yaml .YAMLToJSON ([]byte (yamlString ))
95+ Expect (err ).NotTo (HaveOccurred ())
96+ return workloadv1beta2.AppWrapperComponent {
97+ Template : runtime.RawExtension {Raw : jsonBytes },
98+ }
99+ }
100+
89101const namespacedPodYAML = `
90102apiVersion: v1
91103kind: Pod
@@ -179,6 +191,19 @@ func deployment(replicaCount int, milliCPU int64) workloadv1beta2.AppWrapperComp
179191 }
180192}
181193
194+ func deploymentForInference (replicaCount int , milliCPU int64 ) workloadv1beta2.AppWrapperComponent {
195+ yamlString := fmt .Sprintf (deploymentYAML ,
196+ randName ("deployment" ),
197+ replicaCount ,
198+ resource .NewMilliQuantity (milliCPU , resource .DecimalSI ))
199+
200+ jsonBytes , err := yaml .YAMLToJSON ([]byte (yamlString ))
201+ Expect (err ).NotTo (HaveOccurred ())
202+ return workloadv1beta2.AppWrapperComponent {
203+ Template : runtime.RawExtension {Raw : jsonBytes },
204+ }
205+ }
206+
182207const rayClusterYAML = `
183208apiVersion: ray.io/v1
184209kind: RayCluster
@@ -371,6 +396,20 @@ func rayCluster(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperCompo
371396 }
372397}
373398
399+ func rayClusterForInference (workerCount int , milliCPU int64 ) workloadv1beta2.AppWrapperComponent {
400+ workerCPU := resource .NewMilliQuantity (milliCPU , resource .DecimalSI )
401+ yamlString := fmt .Sprintf (rayClusterYAML ,
402+ randName ("raycluster" ),
403+ workerCount , workerCount , workerCount ,
404+ workerCPU )
405+
406+ jsonBytes , err := yaml .YAMLToJSON ([]byte (yamlString ))
407+ Expect (err ).NotTo (HaveOccurred ())
408+ return workloadv1beta2.AppWrapperComponent {
409+ Template : runtime.RawExtension {Raw : jsonBytes },
410+ }
411+ }
412+
374413const jobSetYAML = `
375414apiVersion: jobset.x-k8s.io/v1alpha2
376415kind: JobSet
@@ -426,3 +465,128 @@ func jobSet(replicasWorker int, milliCPUWorker int64) workloadv1beta2.AppWrapper
426465 Template : runtime.RawExtension {Raw : jsonBytes },
427466 }
428467}
468+
// jobYAML is a fmt template for a batch/v1 Job manifest with a single busybox
// container that runs "sleep 30" and has restartPolicy Never.
// Its %v verbs are, in order of appearance: the metadata name, parallelism,
// completions, and the container's CPU request.
469+ const jobYAML = `
470+ apiVersion: batch/v1
471+ kind: Job
472+ metadata:
473+ name: %v
474+ spec:
475+ parallelism: %v
476+ completions: %v
477+ template:
478+ spec:
479+ restartPolicy: Never
480+ containers:
481+ - name: busybox
482+ image: quay.io/project-codeflare/busybox:1.36
483+ command: ["sh", "-c", "sleep 30"]
484+ resources:
485+ requests:
486+ cpu: %v`
487+
488+ func jobForInference (parallelism int , completions int , milliCPU int64 ) workloadv1beta2.AppWrapperComponent {
489+ yamlString := fmt .Sprintf (jobYAML ,
490+ randName ("job" ),
491+ parallelism ,
492+ completions ,
493+ resource .NewMilliQuantity (milliCPU , resource .DecimalSI ))
494+
495+ jsonBytes , err := yaml .YAMLToJSON ([]byte (yamlString ))
496+ Expect (err ).NotTo (HaveOccurred ())
497+ return workloadv1beta2.AppWrapperComponent {
498+ Template : runtime.RawExtension {Raw : jsonBytes },
499+ }
500+ }
501+
// pytorchJobYAML is a fmt template for a kubeflow.org/v1 PyTorchJob manifest
// with a Master and a Worker replica spec. Both specs use restartPolicy
// OnFailure and a single "pytorch" container running the same image and
// command (mnist.py with --epochs=1).
// Its %v verbs are, in order of appearance: the metadata name, the Master
// container's CPU request, the Worker replicas count, and the Worker
// container's CPU request. The Master spec has no replicas entry.
502+ const pytorchJobYAML = `
503+ apiVersion: "kubeflow.org/v1"
504+ kind: PyTorchJob
505+ metadata:
506+ name: %v
507+ spec:
508+ pytorchReplicaSpecs:
509+ Master:
510+ restartPolicy: OnFailure
511+ template:
512+ spec:
513+ containers:
514+ - name: pytorch
515+ image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1
516+ command:
517+ - "python3"
518+ - "/opt/pytorch-mnist/mnist.py"
519+ - "--epochs=1"
520+ resources:
521+ requests:
522+ cpu: %v
523+ Worker:
524+ replicas: %v
525+ restartPolicy: OnFailure
526+ template:
527+ spec:
528+ containers:
529+ - name: pytorch
530+ image: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-fc858d1
531+ command:
532+ - "python3"
533+ - "/opt/pytorch-mnist/mnist.py"
534+ - "--epochs=1"
535+ resources:
536+ requests:
537+ cpu: %v`
538+
539+ func pytorchJobForInference (masterMilliCPU int64 , workerReplicas int , workerMilliCPU int64 ) workloadv1beta2.AppWrapperComponent {
540+ yamlString := fmt .Sprintf (pytorchJobYAML ,
541+ randName ("pytorch-job" ),
542+ resource .NewMilliQuantity (masterMilliCPU , resource .DecimalSI ),
543+ workerReplicas ,
544+ resource .NewMilliQuantity (workerMilliCPU , resource .DecimalSI ))
545+
546+ jsonBytes , err := yaml .YAMLToJSON ([]byte (yamlString ))
547+ Expect (err ).NotTo (HaveOccurred ())
548+ return workloadv1beta2.AppWrapperComponent {
549+ Template : runtime.RawExtension {Raw : jsonBytes },
550+ }
551+ }
552+
// rayJobYAML is a fmt template for a ray.io/v1 RayJob manifest whose
// rayClusterSpec has a head group and one worker group, both using the
// rayproject/ray:2.9.0 image. The head container's CPU request is fixed at 1.
// Its %v verbs are, in order of appearance: the metadata name, the worker
// group's replicas count, and the worker container's CPU request.
553+ const rayJobYAML = `
554+ apiVersion: ray.io/v1
555+ kind: RayJob
556+ metadata:
557+ name: %v
558+ spec:
559+ rayClusterSpec:
560+ headGroupSpec:
561+ template:
562+ spec:
563+ containers:
564+ - name: ray-head
565+ image: rayproject/ray:2.9.0
566+ resources:
567+ requests:
568+ cpu: 1
569+ workerGroupSpecs:
570+ - replicas: %v
571+ template:
572+ spec:
573+ containers:
574+ - name: ray-worker
575+ image: rayproject/ray:2.9.0
576+ resources:
577+ requests:
578+ cpu: %v
579+ `
580+
581+ func rayJobForInference (workerCount int , milliCPU int64 ) workloadv1beta2.AppWrapperComponent {
582+ yamlString := fmt .Sprintf (rayJobYAML ,
583+ randName ("rayjob" ),
584+ workerCount ,
585+ resource .NewMilliQuantity (milliCPU , resource .DecimalSI ))
586+
587+ jsonBytes , err := yaml .YAMLToJSON ([]byte (yamlString ))
588+ Expect (err ).NotTo (HaveOccurred ())
589+ return workloadv1beta2.AppWrapperComponent {
590+ Template : runtime.RawExtension {Raw : jsonBytes },
591+ }
592+ }
0 commit comments