diff --git a/internal/webhook/appwrapper_webhook.go b/internal/webhook/appwrapper_webhook.go index 402c029..3967a16 100644 --- a/internal/webhook/appwrapper_webhook.go +++ b/internal/webhook/appwrapper_webhook.go @@ -30,6 +30,7 @@ import ( discovery "k8s.io/client-go/discovery" "k8s.io/client-go/kubernetes" authClientv1 "k8s.io/client-go/kubernetes/typed/authorization/v1" + utilmaps "sigs.k8s.io/kueue/pkg/util/maps" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/log" @@ -44,6 +45,11 @@ import ( "github.com/project-codeflare/appwrapper/pkg/utils" ) +const ( + AppWrapperUsernameLabel = "workload.codeflare.dev/user" + AppWrapperUserIDLabel = "workload.codeflare.dev/userid" +) + type AppWrapperWebhook struct { Config *config.AppWrapperConfig SubjectAccessReviewer authClientv1.SubjectAccessReviewInterface @@ -66,6 +72,14 @@ func (w *AppWrapperWebhook) Default(ctx context.Context, obj runtime.Object) err log.FromContext(ctx).Info("Error raised during podSet inference", "job", aw) return err } + + // inject labels with user name and id + request, err := admission.RequestFromContext(ctx) + if err != nil { + return err + } + userInfo := request.UserInfo + aw.Labels = utilmaps.MergeKeepFirst(map[string]string{AppWrapperUsernameLabel: userInfo.Username, AppWrapperUserIDLabel: userInfo.UID}, aw.Labels) return nil } @@ -258,6 +272,14 @@ func (w *AppWrapperWebhook) validateAppWrapperUpdate(old *workloadv1beta2.AppWra } } + // ensure user name and id are not mutated + if old.Labels[AppWrapperUsernameLabel] != new.Labels[AppWrapperUsernameLabel] { + allErrors = append(allErrors, field.Forbidden(field.NewPath("metadata").Child("labels").Key(AppWrapperUsernameLabel), msg)) + } + if old.Labels[AppWrapperUserIDLabel] != new.Labels[AppWrapperUserIDLabel] { + allErrors = append(allErrors, field.Forbidden(field.NewPath("metadata").Child("labels").Key(AppWrapperUserIDLabel), msg)) + } + return allErrors } diff --git a/internal/webhook/appwrapper_webhook_test.go b/internal/webhook/appwrapper_webhook_test.go index abfb109..aeca492 100644 --- a/internal/webhook/appwrapper_webhook_test.go +++ b/internal/webhook/appwrapper_webhook_test.go @@ -27,6 +27,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" + utilmaps "sigs.k8s.io/kueue/pkg/util/maps" ) var _ = Describe("AppWrapper Webhook Tests", func() { @@ -39,6 +40,16 @@ var _ = Describe("AppWrapper Webhook Tests", func() { Expect(aw.Spec.Suspend).Should(BeTrue(), "aw.Spec.Suspend should have been changed to true") Expect(k8sClient.Delete(ctx, aw)).To(Succeed()) }) + + It("User name and ID are set", func() { + aw := toAppWrapper(pod(100)) + aw.Labels = utilmaps.MergeKeepFirst(map[string]string{AppWrapperUsernameLabel: "bad", AppWrapperUserIDLabel: "bad"}, aw.Labels) + + Expect(k8sLimitedClient.Create(ctx, aw)).To(Succeed()) + Expect(aw.Labels[AppWrapperUsernameLabel]).Should(BeIdenticalTo(limitedUserName)) + Expect(aw.Labels[AppWrapperUserIDLabel]).Should(BeIdenticalTo(limitedUserID)) + Expect(k8sLimitedClient.Delete(ctx, aw)).To(Succeed()) + }) }) Context("Validating Webhook", func() { @@ -128,6 +139,36 @@ var _ = Describe("AppWrapper Webhook Tests", func() { Expect(k8sClient.Create(ctx, aw)).ShouldNot(Succeed()) }) + It("User name and ID are immutable", func() { + aw := toAppWrapper(pod(100)) + awName := types.NamespacedName{Name: aw.Name, Namespace: aw.Namespace} + Expect(k8sClient.Create(ctx, aw)).Should(Succeed()) + + aw = getAppWrapper(awName) + aw.Labels[AppWrapperUsernameLabel] = "bad" + Expect(k8sClient.Update(ctx, aw)).ShouldNot(Succeed()) + + aw = getAppWrapper(awName) + aw.Labels[AppWrapperUserIDLabel] = "bad" + Expect(k8sClient.Update(ctx, aw)).ShouldNot(Succeed()) + + Expect(k8sClient.Delete(ctx, aw)).To(Succeed()) + }) + + It("User name and ID should be preserved on updates", func() { + aw := toAppWrapper(pod(100)) + awName := types.NamespacedName{Name: aw.Name, Namespace: aw.Namespace} + Expect(k8sLimitedClient.Create(ctx, aw)).Should(Succeed()) + + aw = getAppWrapper(awName) + Expect(k8sClient.Update(ctx, aw)).Should(Succeed()) + + aw = getAppWrapper(awName) + Expect(aw.Labels[AppWrapperUsernameLabel]).Should(BeIdenticalTo(limitedUserName)) + Expect(aw.Labels[AppWrapperUserIDLabel]).Should(BeIdenticalTo(limitedUserID)) + Expect(k8sLimitedClient.Delete(ctx, aw)).To(Succeed()) + }) + Context("aw.Spec.Components is immutable", func() { It("Updates to non-sensitive fields are allowed", func() { aw := toAppWrapper(pod(100), deployment(4, 100)) diff --git a/internal/webhook/suite_test.go b/internal/webhook/suite_test.go index 04cc4e3..94763ae 100644 --- a/internal/webhook/suite_test.go +++ b/internal/webhook/suite_test.go @@ -60,6 +60,9 @@ var testEnv *envtest.Environment var ctx context.Context var cancel context.CancelFunc +const limitedUserName = "limited-user" +const limitedUserID = "8da0fcfe-6d7f-4f44-b433-d91d22cc1b8c" + func TestControllers(t *testing.T) { RegisterFailHandler(Fail) @@ -115,9 +118,8 @@ var _ = BeforeSuite(func() { Expect(k8sClient).NotTo(BeNil()) // configure a restricted rbac user who can create AppWrappers and Pods but not Deployments - limitedUserName := "limited-user" limitedCfg := *cfg - limitedCfg.Impersonate = rest.ImpersonationConfig{UserName: limitedUserName, Extra: map[string][]string{"xyzzy": {"plugh"}}} + limitedCfg.Impersonate = rest.ImpersonationConfig{UserName: limitedUserName, UID: string(limitedUserID), Extra: map[string][]string{"xyzzy": {"plugh"}}} _, err = testEnv.AddUser(envtest.User{Name: limitedUserName, Groups: []string{}}, &limitedCfg) Expect(err).NotTo(HaveOccurred()) clusterRole := &rbacv1.ClusterRole{