diff --git a/slice/cmd/main.go b/slice/cmd/main.go index dca3d4e91..6353c28d7 100644 --- a/slice/cmd/main.go +++ b/slice/cmd/main.go @@ -41,6 +41,7 @@ import ( "tpu-slice-controller/api/v1beta1" "tpu-slice-controller/internal/controller" + "tpu-slice-controller/internal/core" "tpu-slice-controller/internal/util/cert" "tpu-slice-controller/internal/webhooks" @@ -74,6 +75,7 @@ func main() { var probeAddr string var secureMetrics bool var enableHTTP2 bool + var sliceHealthNodeAffinityMode string var tlsOpts []func(*tls.Config) flag.StringVar(&metricsAddr, "metrics-bind-address", "0", "The address the metrics endpoint binds to. "+ "Use :8443 for HTTPS or :8080 for HTTP, or leave as 0 to disable the metrics service.") @@ -95,12 +97,25 @@ func main() { flag.StringVar(&metricsCertKey, "metrics-cert-key", "tls.key", "The name of the metrics server key file.") flag.BoolVar(&enableHTTP2, "enable-http2", false, "If set, HTTP/2 will be enabled for the metrics and webhook servers") + flag.StringVar(&sliceHealthNodeAffinityMode, "default-slice-health-node-affinity", "HEALTHY", + "Default slice health node affinity. Possible values are HEALTHY or HEALTHY_AND_DEGRADED.") opts := zap.Options{ Development: true, } opts.BindFlags(flag.CommandLine) flag.Parse() + var sliceHealthValues []string + switch sliceHealthNodeAffinityMode { + case "HEALTHY": + sliceHealthValues = []string{core.TPUSliceHealthNodeSelectorHealthy} + case "HEALTHY_AND_DEGRADED": + sliceHealthValues = []string{core.TPUSliceHealthNodeSelectorHealthy, core.TPUSliceHealthNodeSelectorDegraded} + default: + setupLog.Error(errors.New("invalid flag value"), "Invalid value for default-slice-health-node-affinity", "value", sliceHealthNodeAffinityMode) + os.Exit(1) + } + ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts))) // if the enable-http2 flag is false (the default), http/2 should be disabled @@ -244,7 +259,7 @@ func main() { os.Exit(1) } - go setupControllers(mgr, certsReady, activationTimeout, retryDelayOnSliceFailure) + go setupControllers(mgr, certsReady, activationTimeout, retryDelayOnSliceFailure, sliceHealthValues) setupProbeEndpoints(mgr, certsReady) @@ -255,17 +270,17 @@ func main() { } } -func setupControllers(mgr ctrl.Manager, certsReady chan struct{}, activationTimeout time.Duration, retryDelay time.Duration) { +func setupControllers(mgr ctrl.Manager, certsReady chan struct{}, activationTimeout time.Duration, retryDelay time.Duration, sliceHealthValues []string) { // The controllers won't work until the webhooks are operating, and the webhook won't work until the // certs are all in place. cert.WaitForCertsReady(setupLog, certsReady) // Register the webhooks - if err := webhooks.SetupWebhookWithManager(mgr); err != nil { + if err := webhooks.SetupWebhookWithManager(mgr, sliceHealthValues); err != nil { setupLog.Error(err, "Unable to create webhook", "webhook", "JobSet") os.Exit(1) } - if err := webhooks.SetupJobWebhookWithManager(mgr); err != nil { + if err := webhooks.SetupJobWebhookWithManager(mgr, sliceHealthValues); err != nil { setupLog.Error(err, "Unable to create webhook", "webhook", "Job") os.Exit(1) } diff --git a/slice/internal/webhooks/defaults.go b/slice/internal/webhooks/defaults.go index eea9881a7..4dee3c36b 100644 --- a/slice/internal/webhooks/defaults.go +++ b/slice/internal/webhooks/defaults.go @@ -38,7 +38,7 @@ func getTPUsRequestedPerPod(spec corev1.PodSpec) int64 { return totalTPUs } -func annotatePodTemplateSpecWithSliceHealth(template *corev1.PodTemplateSpec) { +func annotatePodTemplateSpecWithSliceHealth(template *corev1.PodTemplateSpec, defaultSliceHealthValues []string) { // 1. If there is NodeSelector with TPUSliceHealthNodeSelectorKey, we do nothing. if _, ok := template.Spec.NodeSelector[core.TPUSliceHealthNodeSelectorKey]; ok { return @@ -58,7 +58,7 @@ func annotatePodTemplateSpecWithSliceHealth(template *corev1.PodTemplateSpec) { } // 3. If neither of these, we add a NodeAffinity. - core.AddNodeAffinity(template, core.TPUSliceHealthNodeSelectorKey, []string{core.TPUSliceHealthNodeSelectorHealthy}) + core.AddNodeAffinity(template, core.TPUSliceHealthNodeSelectorKey, defaultSliceHealthValues) } func annotatePodTemplateSpecWithTopology(template *corev1.PodTemplateSpec, parallelism *int32, resourceName string, resourceKind string) error { diff --git a/slice/internal/webhooks/job_webhook.go b/slice/internal/webhooks/job_webhook.go index 3db1498de..f44f8e9ee 100644 --- a/slice/internal/webhooks/job_webhook.go +++ b/slice/internal/webhooks/job_webhook.go @@ -28,12 +28,16 @@ import ( "tpu-slice-controller/internal/core" ) -type JobWebhook struct{} +type JobWebhook struct { + DefaultSliceHealthValues []string +} -func SetupJobWebhookWithManager(mgr ctrl.Manager) error { +func SetupJobWebhookWithManager(mgr ctrl.Manager, defaultSliceHealthValues []string) error { return ctrl.NewWebhookManagedBy(mgr). For(&batchv1.Job{}). - WithDefaulter(&JobWebhook{}). + WithDefaulter(&JobWebhook{ + DefaultSliceHealthValues: defaultSliceHealthValues, + }). Complete() } @@ -56,7 +60,7 @@ func (r *JobWebhook) Default(ctx context.Context, obj runtime.Object) error { return nil } log.V(5).Info("Annotating Job") - annotatePodTemplateSpecWithSliceHealth(&job.Spec.Template) + annotatePodTemplateSpecWithSliceHealth(&job.Spec.Template, r.DefaultSliceHealthValues) err := annotatePodTemplateSpecWithTopology(&job.Spec.Template, job.Spec.Parallelism, job.Name, "job") if err != nil { return err diff --git a/slice/internal/webhooks/job_webhook_test.go b/slice/internal/webhooks/job_webhook_test.go index 7ac16c358..69080bce6 100644 --- a/slice/internal/webhooks/job_webhook_test.go +++ b/slice/internal/webhooks/job_webhook_test.go @@ -232,7 +232,9 @@ func TestJobDefault(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { ctx := t.Context() - webhook := &JobWebhook{} + webhook := &JobWebhook{ + DefaultSliceHealthValues: []string{core.TPUSliceHealthNodeSelectorHealthy}, + } gotErr := webhook.Default(ctx, tc.job) if diff := cmp.Diff(tc.wantErr, gotErr, utiltesting.EquateErrors); diff != "" { diff --git a/slice/internal/webhooks/jobset_webhook.go b/slice/internal/webhooks/jobset_webhook.go index b3f04ca74..f091494ed 100644 --- a/slice/internal/webhooks/jobset_webhook.go +++ b/slice/internal/webhooks/jobset_webhook.go @@ -29,12 +29,16 @@ import ( ) // JobSetWebhook is the schema for your resource (ensure this matches your resource definition). -type JobSetWebhook struct{} +type JobSetWebhook struct { + DefaultSliceHealthValues []string +} -func SetupWebhookWithManager(mgr ctrl.Manager) error { +func SetupWebhookWithManager(mgr ctrl.Manager, defaultSliceHealthValues []string) error { return ctrl.NewWebhookManagedBy(mgr). For(&v1alpha2.JobSet{}). - WithDefaulter(&JobSetWebhook{}). + WithDefaulter(&JobSetWebhook{ + DefaultSliceHealthValues: defaultSliceHealthValues, + }). Complete() } @@ -59,7 +63,7 @@ func (r *JobSetWebhook) Default(ctx context.Context, obj runtime.Object) error { continue } log.V(5).Info("Annotating ReplicatedJob") - annotatePodTemplateSpecWithSliceHealth(&rj.Template.Spec.Template) + annotatePodTemplateSpecWithSliceHealth(&rj.Template.Spec.Template, r.DefaultSliceHealthValues) err := annotatePodTemplateSpecWithTopology(&rj.Template.Spec.Template, rj.Template.Spec.Parallelism, rj.Name, "replicated job") if err != nil { return err diff --git a/slice/internal/webhooks/jobset_webhook_test.go b/slice/internal/webhooks/jobset_webhook_test.go index 8f901fe8f..2ebe6c4a9 100644 --- a/slice/internal/webhooks/jobset_webhook_test.go +++ b/slice/internal/webhooks/jobset_webhook_test.go @@ -36,9 +36,10 @@ func TestDefault(t *testing.T) { ) testCases := map[string]struct { - jobSet *jobset.JobSet - wantJobSet *jobset.JobSet - wantErr error + defaultSliceHealthValues []string + jobSet *jobset.JobSet + wantJobSet *jobset.JobSet + wantErr error }{ "no queue label": { jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace). @@ -111,6 +112,7 @@ func TestDefault(t *testing.T) { Obj(), }, "should set default values": { + defaultSliceHealthValues: []string{core.TPUSliceHealthNodeSelectorHealthy}, jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace). Queue("queue-name"). ReplicatedJobs(testingjobjobset.ReplicatedJobRequirements{ @@ -143,6 +145,40 @@ func TestDefault(t *testing.T) { RequestAndLimit("rj1", core.TPUResourceName, "4"). Obj(), }, + "should set default values including DEGRADED cube health": { + defaultSliceHealthValues: []string{core.TPUSliceHealthNodeSelectorHealthy, core.TPUSliceHealthNodeSelectorDegraded}, + jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace). + Queue("queue-name"). + ReplicatedJobs(testingjobjobset.ReplicatedJobRequirements{ + Name: "rj1", + Parallelism: 48, + PodAnnotations: map[string]string{ + core.TPUSliceTopologyAnnotation: "4x4x12", + }, + NodeSelector: map[string]string{ + "cloud.google.com/gke-tpu-accelerator": string(slice.TypeTpu7x), + }, + }). + RequestAndLimit("rj1", core.TPUResourceName, "4"). + Obj(), + wantJobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace). + Queue("queue-name"). + ReplicatedJobs(testingjobjobset.ReplicatedJobRequirements{ + Name: "rj1", + Parallelism: 48, + PodAnnotations: map[string]string{ + core.TPUSliceTopologyAnnotation: "4x4x12", + "kueue.x-k8s.io/podset-required-topology": "cloud.google.com/gce-topology-block", + "kueue.x-k8s.io/podset-slice-required-topology": core.TPUSubBlockLabel, + "kueue.x-k8s.io/podset-slice-size": "16", + }, + NodeSelector: map[string]string{ + "cloud.google.com/gke-tpu-accelerator": string(slice.TypeTpu7x), + }, + }).NodeAffinity("rj1", core.TPUSliceHealthNodeSelectorKey, []string{core.TPUSliceHealthNodeSelectorHealthy, core.TPUSliceHealthNodeSelectorDegraded}). + RequestAndLimit("rj1", core.TPUResourceName, "4"). + Obj(), + }, "shouldn't set default values because invalid topology annotation": { jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace). Queue("queue-name"). @@ -308,7 +344,9 @@ func TestDefault(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { ctx := t.Context() - webhook := &JobSetWebhook{} + webhook := &JobSetWebhook{ + DefaultSliceHealthValues: tc.defaultSliceHealthValues, + } gotErr := webhook.Default(ctx, tc.jobSet) if diff := cmp.Diff(tc.wantErr, gotErr, utiltesting.EquateErrors); diff != "" {