
Commit f40324a

Scheduler and PodGrouper use configurable nodepool label key (#179)
1 parent db692db

32 files changed: +214 additions, −170 deletions

cmd/podgrouper/app/options.go

Lines changed: 1 addition & 2 deletions
@@ -6,7 +6,6 @@ package app
 import (
     "flag"

-    "github.com/NVIDIA/KAI-scheduler/pkg/common/constants"
     controllers "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper"
 )

@@ -28,7 +27,7 @@ func (o *Options) AddFlags(fs *flag.FlagSet) {
     fs.StringVar(&o.MetricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.")
     fs.StringVar(&o.ProbeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
     fs.BoolVar(&o.EnableLeaderElection, "leader-elect", false, "Enable leader election for controller manager. Enabling this will ensure there is only one active controller manager.")
-    fs.StringVar(&o.NodePoolLabelKey, "nodepool-label-key", constants.NodePoolNameLabel, "The label key for node pools")
+    fs.StringVar(&o.NodePoolLabelKey, "nodepool-label-key", "", "The label key for node pools")
     fs.IntVar(&o.QPS, "qps", 50, "Queries per second to the K8s API server")
     fs.IntVar(&o.Burst, "burst", 300, "Burst to the K8s API server")
     fs.IntVar(&o.MaxConcurrentReconciles, "max-concurrent-reconciles", 10, "Max concurrent reconciles")
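
The flag's default is now empty instead of the removed runai/node-pool constant, so node-pool labeling becomes opt-in. A minimal sketch of that behavior, assuming a trimmed-down Options struct (the real one in cmd/podgrouper/app has more fields):

package main

import (
    "flag"
    "fmt"
)

// Options mirrors only the field relevant here.
type Options struct {
    NodePoolLabelKey string
}

func main() {
    opts := &Options{}
    fs := flag.NewFlagSet("podgrouper", flag.ExitOnError)
    // Empty default: unless the operator passes --nodepool-label-key,
    // pod groups are not labeled per node pool.
    fs.StringVar(&opts.NodePoolLabelKey, "nodepool-label-key", "", "The label key for node pools")
    _ = fs.Parse([]string{"--nodepool-label-key=kai.scheduler/node-pool"})

    if opts.NodePoolLabelKey == "" {
        fmt.Println("node-pool labeling disabled")
        return
    }
    fmt.Printf("grouping by node-pool label %q\n", opts.NodePoolLabelKey)
}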

pkg/common/constants/constants.go

Lines changed: 0 additions & 1 deletion
@@ -26,7 +26,6 @@ const (
     LastStartTimeStamp = "kai.scheduler/last-start-timestamp"

     // Labels
-    NodePoolNameLabel = "runai/node-pool"
     GPUGroup = "runai-gpu-group"
     MultiGpuGroupLabelPrefix = GPUGroup + "/"
     MigEnabledLabel = "node-role.kubernetes.io/runai-mig-enabled"

pkg/podgrouper/pod_controller.go

Lines changed: 4 additions & 2 deletions
@@ -103,7 +103,9 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
         return ctrl.Result{}, err
     }

-    addNodePoolLabel(metadata, &pod, r.configs.NodePoolLabelKey)
+    if len(r.configs.NodePoolLabelKey) > 0 {
+        addNodePoolLabel(metadata, &pod, r.configs.NodePoolLabelKey)
+    }

     err = r.PodGroupHandler.ApplyToCluster(ctx, *metadata)
     if err != nil {
@@ -128,7 +130,7 @@ func (r *PodReconciler) SetupWithManager(mgr ctrl.Manager, configs Configs) erro
     }

     r.podGrouper = podgrouper.NewPodgrouper(mgr.GetClient(), clientWithoutCache, configs.SearchForLegacyPodGroups,
-        configs.KnativeGangSchedule, configs.SchedulingQueueLabelKey)
+        configs.KnativeGangSchedule, configs.SchedulingQueueLabelKey, configs.NodePoolLabelKey)
     r.PodGroupHandler = podgroup.NewHandler(mgr.GetClient(), configs.NodePoolLabelKey, configs.SchedulingQueueLabelKey)
     r.configs = configs
     r.eventRecorder = mgr.GetEventRecorderFor(controllerName)
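
The reconciler now skips node-pool labeling entirely when no key is configured. A hedged sketch of that guard; the body of addNodePoolLabel is not shown in this diff, so the copy-the-label behavior below is an assumption:

package podgrouper

import v1 "k8s.io/api/core/v1"

// addNodePoolLabelSketch approximates the guarded call in Reconcile:
// do nothing when no key is configured, otherwise copy the pod's
// node-pool label (if present) onto the pod-group labels.
func addNodePoolLabelSketch(groupLabels map[string]string, pod *v1.Pod, nodePoolLabelKey string) {
    if len(nodePoolLabelKey) == 0 {
        return // feature disabled: --nodepool-label-key was left empty
    }
    if nodePool, found := pod.GetLabels()[nodePoolLabelKey]; found {
        groupLabels[nodePoolLabelKey] = nodePool
    }
}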

pkg/podgrouper/pod_controller_test.go

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ import (
     "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgroup"
 )

-const nodePoolKey = "runai/node-pool"
+const nodePoolKey = "kai.scheduler/node-pool"

 func TestAddNodePoolLabel(t *testing.T) {
     metadata := podgroup.Metadata{

pkg/podgrouper/podgrouper/hub/hub.go

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ func (ph *PluginsHub) GetPodGrouperPlugin(gvk metav1.GroupVersionKind) grouper.G
 }

 func NewPluginsHub(kubeClient client.Client, searchForLegacyPodGroups,
-    gangScheduleKnative bool, queueLabelKey string) *PluginsHub {
-    defaultGrouper := defaultgrouper.NewDefaultGrouper(queueLabelKey)
+    gangScheduleKnative bool, queueLabelKey, nodePoolLabelKey string) *PluginsHub {
+    defaultGrouper := defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)

     kubeFlowDistributedGrouper := kubeflow.NewKubeflowDistributedGrouper(defaultGrouper)
     mpiGrouper := mpi.NewMpiGrouper(kubeClient, kubeFlowDistributedGrouper)

pkg/podgrouper/podgrouper/hub/hub_test.go

Lines changed: 8 additions & 3 deletions
@@ -15,7 +15,8 @@ import (
 )

 const (
-    queueLabelKey = "kai.scheduler/queue"
+    queueLabelKey    = "kai.scheduler/queue"
+    nodePoolLabelKey = "kai.scheduler/node-pool"
 )

 func TestSupportedTypes(t *testing.T) {
@@ -32,7 +33,9 @@ var _ = Describe("SupportedTypes", func() {

     BeforeEach(func() {
         kubeClient = fake.NewFakeClient()
-        hub = NewPluginsHub(kubeClient, false, false, queueLabelKey)
+        hub = NewPluginsHub(
+            kubeClient, false, false, queueLabelKey, nodePoolLabelKey,
+        )
     })

     It("should return plugin for exact GVK match", func() {
@@ -66,7 +69,9 @@ var _ = Describe("SupportedTypes", func() {

     BeforeEach(func() {
         kubeClient = fake.NewFakeClient()
-        hub = NewPluginsHub(kubeClient, false, false, queueLabelKey)
+        hub = NewPluginsHub(
+            kubeClient, false, false, queueLabelKey, nodePoolLabelKey,
+        )
     })

     It("should successfully retrieve with any version for kind set with wildcard", func() {

pkg/podgrouper/podgrouper/plugins/aml/aml_grouper_test.go

Lines changed: 6 additions & 3 deletions
@@ -13,7 +13,10 @@ import (
     "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/defaultgrouper"
 )

-const queueLabelKey = "kai.scheduler/queue"
+const (
+    queueLabelKey    = "kai.scheduler/queue"
+    nodePoolLabelKey = "kai.scheduler/node-pool"
+)

 func TestGetPodGroupMetadata(t *testing.T) {
     owner := &unstructured.Unstructured{
@@ -44,7 +47,7 @@ func TestGetPodGroupMetadata(t *testing.T) {
     }
     pod := &v1.Pod{}

-    amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey))
+    amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
     podGroupMetadata, err := amlGrouper.GetPodGroupMetadata(owner, pod)

     assert.Nil(t, err)
@@ -86,7 +89,7 @@ func TestGetPodGroupMetadataWithoutReplicas(t *testing.T) {
     }
     pod := &v1.Pod{}

-    amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey))
+    amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
     _, err := amlGrouper.GetPodGroupMetadata(owner, pod)

     assert.NotNil(t, err)

pkg/podgrouper/podgrouper/plugins/cronjobs/cronjob_grouper_test.go

Lines changed: 9 additions & 8 deletions
@@ -22,11 +22,12 @@ import (
 )

 const (
-    podName       = "my-pod"
-    jobName       = "cron-27958584"
-    jobUID        = types.UID("123456789")
-    cronjobName   = "cron"
-    queueLabelKey = "kai.scheduler/queue"
+    podName          = "my-pod"
+    jobName          = "cron-27958584"
+    jobUID           = types.UID("123456789")
+    cronjobName      = "cron"
+    queueLabelKey    = "kai.scheduler/queue"
+    nodePoolLabelKey = "kai.scheduler/node-pool"
 )

 func TestGetPodGroupMetadata(t *testing.T) {
@@ -66,7 +67,7 @@ func TestGetPodGroupMetadata(t *testing.T) {
     assert.Nil(t, err)

     client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(job).Build()
-    grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey))
+    grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
     podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)

     assert.Nil(t, err)
@@ -104,7 +105,7 @@ func TestGetPodGroupMetadataJobOwnerNotFound(t *testing.T) {
     }

     client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(cronjob).Build()
-    grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey))
+    grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
     podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)

     assert.NotNil(t, err)
@@ -143,7 +144,7 @@ func TestGetPodGroupMetadataJobNotExists(t *testing.T) {
     assert.Nil(t, err)

     client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(cronjob).Build()
-    grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey))
+    grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
     podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)
     assert.NotNil(t, err)
     assert.Equal(t, fmt.Sprintf("jobs.batch \"%s\" not found", jobName), err.Error())

pkg/podgrouper/podgrouper/plugins/defaultgrouper/default_grouper.go

Lines changed: 8 additions & 6 deletions
@@ -25,12 +25,14 @@ var (
 )

 type DefaultGrouper struct {
-    queueLabelKey string
+    queueLabelKey    string
+    nodePoolLabelKey string
 }

-func NewDefaultGrouper(queueLabelKey string) *DefaultGrouper {
+func NewDefaultGrouper(queueLabelKey, nodePoolLabelKey string) *DefaultGrouper {
     return &DefaultGrouper{
-        queueLabelKey: queueLabelKey,
+        queueLabelKey:    queueLabelKey,
+        nodePoolLabelKey: nodePoolLabelKey,
     }
 }

@@ -106,15 +108,15 @@ func (dg *DefaultGrouper) CalcPodGroupQueue(topOwner *unstructured.Unstructured,
         return queue
     }

-    queue := calculateQueueName(topOwner, pod)
+    queue := dg.calculateQueueName(topOwner, pod)
     if queue != "" {
         return queue
     }

     return constants.DefaultQueueName
 }

-func calculateQueueName(topOwner *unstructured.Unstructured, pod *v1.Pod) string {
+func (dg *DefaultGrouper) calculateQueueName(topOwner *unstructured.Unstructured, pod *v1.Pod) string {
     project := ""
     if projectLabel, found := topOwner.GetLabels()[constants.ProjectLabelKey]; found {
         project = projectLabel
@@ -126,7 +128,7 @@ func calculateQueueName(topOwner *unstructured.Unstructured, pod *v1.Pod) string
         return ""
     }

-    if nodePool, found := pod.GetLabels()[commonconsts.NodePoolNameLabel]; found {
+    if nodePool, found := pod.GetLabels()[dg.nodePoolLabelKey]; found {
         return fmt.Sprintf("%s-%s", project, nodePool)
     }
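The queue-name fallback now reads the node pool from the pod via the injected key instead of the removed NodePoolNameLabel constant. A self-contained sketch of the resulting naming; the hunk does not show what happens when the pod lacks a node-pool label, so the project-only fallback here is an assumption:

package main

import "fmt"

// queueNameFor mimics calculateQueueName's final step: "<project>-<nodePool>"
// when the configured node-pool label is present on the pod, else the project alone.
func queueNameFor(project string, podLabels map[string]string, nodePoolLabelKey string) string {
    if project == "" {
        return "" // caller then falls back to constants.DefaultQueueName
    }
    if nodePool, found := podLabels[nodePoolLabelKey]; found {
        return fmt.Sprintf("%s-%s", project, nodePool)
    }
    return project
}

func main() {
    labels := map[string]string{"kai.scheduler/node-pool": "a100-pool"}
    fmt.Println(queueNameFor("team-a", labels, "kai.scheduler/node-pool")) // team-a-a100-pool
    fmt.Println(queueNameFor("team-a", nil, "kai.scheduler/node-pool"))    // team-a
}
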
pkg/podgrouper/podgrouper/plugins/defaultgrouper/default_grouper_owners_test.go

Lines changed: 8 additions & 5 deletions
@@ -11,7 +11,10 @@ import (
     "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
 )

-const queueLabelKey = "kai.scheduler/queue"
+const (
+    queueLabelKey    = "kai.scheduler/queue"
+    nodePoolLabelKey = "kai.scheduler/node-pool"
+)

 func TestGetPodGroupMetadata_KubeflowPipelineScheduledWorkflow(t *testing.T) {
     owner := &unstructured.Unstructured{
@@ -54,7 +57,7 @@ func TestGetPodGroupMetadata_KubeflowPipelineScheduledWorkflow(t *testing.T) {
     }
     pod := &v1.Pod{}

-    defaultGrouper := NewDefaultGrouper(queueLabelKey)
+    defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
     podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)

     assert.Nil(t, err)
@@ -126,7 +129,7 @@ func TestGetPodGroupMetadata_ArgoWorkflow(t *testing.T) {
     }
     pod := &v1.Pod{}

-    defaultGrouper := NewDefaultGrouper(queueLabelKey)
+    defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
     podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)

     assert.Nil(t, err)
@@ -195,7 +198,7 @@ func TestGetPodGroupMetadata_Tekton_TaskRun(t *testing.T) {
     }
     pod := &v1.Pod{}

-    defaultGrouper := NewDefaultGrouper(queueLabelKey)
+    defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
     podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)

     assert.Nil(t, err)
@@ -293,7 +296,7 @@ func TestGetPodGroupMetadata_Tekton_PipelineRun(t *testing.T) {
     }
     pod := &v1.Pod{}

-    defaultGrouper := NewDefaultGrouper(queueLabelKey)
+    defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
     podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)

     assert.Nil(t, err)
