Skip to content

Commit da63f05

Browse files
authored
Natasha/validate priority exists for podgroup (#330)
* podgrouper: when calculating the priority class name for a podgroup, validate that the priority class exists
* fixed tests
* some refactoring: added kubeclient to the params of NewDefaultGrouper
* added some tests
* added generated RBAC file
* CR fix
1 parent 33bd9d0 commit da63f05

File tree

21 files changed

+385
-157
lines changed

21 files changed

+385
-157
lines changed

deployments/kai-scheduler/templates/rbac/podgrouper.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ rules:
276276
- create
277277
- patch
278278
- update
279+
- apiGroups:
280+
- scheduling.k8s.io
281+
resources:
282+
- priorityclasses
283+
verbs:
284+
- get
285+
- list
286+
- watch
279287
- apiGroups:
280288
- scheduling.run.ai
281289
resources:

pkg/podgrouper/pod_controller.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type Configs struct {
6161
// +kubebuilder:rbac:groups="",resources=configmaps,verbs=get;list;watch
6262
// +kubebuilder:rbac:groups="",resources=pods,verbs=create;update;patch
6363
// +kubebuilder:rbac:groups="",resources=events,verbs=create;patch;update;get;list;watch
64+
// +kubebuilder:rbac:groups="scheduling.k8s.io",resources=priorityclasses,verbs=get;list;watch
6465
// +kubebuilder:rbac:groups="scheduling.run.ai",resources=podgroups,verbs=create;update;patch;get;list;watch
6566

6667
// Reconcile is part of the main kubernetes reconciliation loop which aims to

pkg/podgrouper/podgrouper/hub/hub.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ func (ph *PluginsHub) GetPodGrouperPlugin(gvk metav1.GroupVersionKind) grouper.G
7676
func NewPluginsHub(kubeClient client.Client, searchForLegacyPodGroups,
7777
gangScheduleKnative bool, queueLabelKey, nodePoolLabelKey string,
7878
defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace string) *PluginsHub {
79-
defaultGrouper := defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
80-
defaultGrouper.SetDefaultPrioritiesConfigMapParams(defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace, kubeClient)
79+
defaultGrouper := defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, kubeClient)
80+
defaultGrouper.SetDefaultPrioritiesConfigMapParams(defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace)
8181

8282
kubeFlowDistributedGrouper := kubeflow.NewKubeflowDistributedGrouper(defaultGrouper)
8383
mpiGrouper := mpi.NewMpiGrouper(kubeClient, kubeFlowDistributedGrouper)

pkg/podgrouper/podgrouper/plugins/aml/aml_grouper_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/stretchr/testify/assert"
1010
v1 "k8s.io/api/core/v1"
1111
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
12+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
1213

1314
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/defaultgrouper"
1415
)
@@ -47,7 +48,7 @@ func TestGetPodGroupMetadata(t *testing.T) {
4748
}
4849
pod := &v1.Pod{}
4950

50-
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
51+
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient()))
5152
podGroupMetadata, err := amlGrouper.GetPodGroupMetadata(owner, pod)
5253

5354
assert.Nil(t, err)
@@ -89,7 +90,7 @@ func TestGetPodGroupMetadataWithoutReplicas(t *testing.T) {
8990
}
9091
pod := &v1.Pod{}
9192

92-
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
93+
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient()))
9394
_, err := amlGrouper.GetPodGroupMetadata(owner, pod)
9495

9596
assert.NotNil(t, err)

pkg/podgrouper/podgrouper/plugins/cronjobs/cronjob_grouper_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func TestGetPodGroupMetadata(t *testing.T) {
6767
assert.Nil(t, err)
6868

6969
client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(job).Build()
70-
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
70+
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, client))
7171
podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)
7272

7373
assert.Nil(t, err)
@@ -105,7 +105,7 @@ func TestGetPodGroupMetadataJobOwnerNotFound(t *testing.T) {
105105
}
106106

107107
client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(cronjob).Build()
108-
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
108+
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, client))
109109
podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)
110110

111111
assert.NotNil(t, err)
@@ -144,7 +144,7 @@ func TestGetPodGroupMetadataJobNotExists(t *testing.T) {
144144
assert.Nil(t, err)
145145

146146
client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(cronjob).Build()
147-
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
147+
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, client))
148148
podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)
149149
assert.NotNil(t, err)
150150
assert.Equal(t, fmt.Sprintf("jobs.batch \"%s\" not found", jobName), err.Error())

pkg/podgrouper/podgrouper/plugins/defaultgrouper/default_grouper.go

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"golang.org/x/exp/maps"
1212
v1 "k8s.io/api/core/v1"
13+
schedulingv1 "k8s.io/api/scheduling/v1"
1314
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1415
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
1516
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -36,19 +37,17 @@ type DefaultGrouper struct {
3637
kubeReader client.Reader
3738
}
3839

39-
func NewDefaultGrouper(queueLabelKey, nodePoolLabelKey string) *DefaultGrouper {
40+
func NewDefaultGrouper(queueLabelKey, nodePoolLabelKey string, kubeReader client.Reader) *DefaultGrouper {
4041
return &DefaultGrouper{
4142
queueLabelKey: queueLabelKey,
4243
nodePoolLabelKey: nodePoolLabelKey,
44+
kubeReader: kubeReader,
4345
}
4446
}
4547

46-
func (dg *DefaultGrouper) SetDefaultPrioritiesConfigMapParams(
47-
defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace string, kubeReader client.Reader,
48-
) {
48+
func (dg *DefaultGrouper) SetDefaultPrioritiesConfigMapParams(defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace string) {
4949
dg.defaultPrioritiesConfigMapName = defaultPrioritiesConfigMapName
5050
dg.defaultPrioritiesConfigMapNamespace = defaultPrioritiesConfigMapNamespace
51-
dg.kubeReader = kubeReader
5251
}
5352

5453
func (dg *DefaultGrouper) Name() string {
@@ -156,29 +155,63 @@ func (dg *DefaultGrouper) calculateQueueName(topOwner *unstructured.Unstructured
156155

157156
func (dg *DefaultGrouper) CalcPodGroupPriorityClass(topOwner *unstructured.Unstructured, pod *v1.Pod,
158157
defaultPriorityClassForJob string) string {
158+
priorityClassName := dg.calcPodGroupPriorityClass(topOwner, pod)
159+
if dg.validatePriorityClassExists(priorityClassName) {
160+
return priorityClassName
161+
}
162+
163+
if priorityClassName != "" {
164+
logger.V(2).Info("priorityClassName from pod or owner labels is not valid, falling back to default",
165+
"priorityClassName", priorityClassName, "topOwner", topOwner.GetName(), "pod", pod.GetName())
166+
}
167+
168+
groupKind := topOwner.GroupVersionKind().GroupKind()
169+
priorityClassName = dg.getDefaultPriorityClassNameForKind(&groupKind)
170+
if dg.validatePriorityClassExists(priorityClassName) {
171+
return priorityClassName
172+
}
173+
174+
logger.V(2).Info("No default priority class found for group kind, using default fallback",
175+
"groupKind", groupKind.String(), "defaultFallback", defaultPriorityClassForJob)
176+
return defaultPriorityClassForJob
177+
}
178+
179+
func (dg *DefaultGrouper) calcPodGroupPriorityClass(topOwner *unstructured.Unstructured, pod *v1.Pod) string {
159180
if priorityClassName, found := topOwner.GetLabels()[constants.PriorityLabelKey]; found {
160181
return priorityClassName
161182
} else if priorityClassName, found = pod.GetLabels()[constants.PriorityLabelKey]; found {
162183
return priorityClassName
163184
} else if len(pod.Spec.PriorityClassName) != 0 {
164185
return pod.Spec.PriorityClassName
165-
} else {
166-
groupKind := topOwner.GroupVersionKind().GroupKind()
167-
return dg.getDefaultPriorityClassNameForKind(&groupKind, defaultPriorityClassForJob)
168186
}
187+
return ""
188+
}
189+
190+
func (dg *DefaultGrouper) validatePriorityClassExists(priorityClassName string) bool {
191+
if priorityClassName == "" || dg.kubeReader == nil {
192+
return false
193+
}
194+
195+
priorityClass := &schedulingv1.PriorityClass{}
196+
err := dg.kubeReader.Get(context.Background(), client.ObjectKey{Name: priorityClassName}, priorityClass)
197+
if err != nil {
198+
logger.V(1).Error(err, "Failed to get priority class", "priorityClassName", priorityClassName)
199+
return false
200+
}
201+
return true
169202
}
170203

171204
// getDefaultPriorityClassNameForKind - returns the default priority class name for a given group kind.
172-
func (dg *DefaultGrouper) getDefaultPriorityClassNameForKind(groupKind *schema.GroupKind, defaultPriorityClassFallback string) string {
205+
func (dg *DefaultGrouper) getDefaultPriorityClassNameForKind(groupKind *schema.GroupKind) string {
173206
if groupKind == nil || groupKind.String() == "" || groupKind.Kind == "" {
174207
logger.V(3).Info("Unable to get default priority class name: GroupKind is empty, using default priority class fallback")
175-
return defaultPriorityClassFallback
208+
return ""
176209
}
177210

178211
defaultPriorities, err := dg.getDefaultPrioritiesPerTypeMapping()
179212
if err != nil {
180213
logger.V(1).Error(err, "Unable to get default priorities mapping")
181-
return defaultPriorityClassFallback
214+
return ""
182215
}
183216

184217
// Check if the groupKind is in the default priorities map.
@@ -192,9 +225,7 @@ func (dg *DefaultGrouper) getDefaultPriorityClassNameForKind(groupKind *schema.G
192225
return priorityClassName
193226
}
194227

195-
logger.V(4).Info("No default priority class found for group kind, using default fallback",
196-
"groupKind", groupKind.String(), "defaultFallback", defaultPriorityClassFallback)
197-
return defaultPriorityClassFallback
228+
return ""
198229
}
199230

200231
// getDefaultPrioritiesPerTypeMapping - returns a map of workload type to default priority class name.

pkg/podgrouper/podgrouper/plugins/defaultgrouper/default_grouper_owners_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/stretchr/testify/assert"
1010
v1 "k8s.io/api/core/v1"
1111
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
12+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
1213
)
1314

1415
const (
@@ -57,7 +58,7 @@ func TestGetPodGroupMetadata_KubeflowPipelineScheduledWorkflow(t *testing.T) {
5758
}
5859
pod := &v1.Pod{}
5960

60-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
61+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
6162
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
6263

6364
assert.Nil(t, err)
@@ -129,7 +130,7 @@ func TestGetPodGroupMetadata_ArgoWorkflow(t *testing.T) {
129130
}
130131
pod := &v1.Pod{}
131132

132-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
133+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
133134
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
134135

135136
assert.Nil(t, err)
@@ -198,7 +199,7 @@ func TestGetPodGroupMetadata_Tekton_TaskRun(t *testing.T) {
198199
}
199200
pod := &v1.Pod{}
200201

201-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
202+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
202203
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
203204

204205
assert.Nil(t, err)
@@ -296,7 +297,7 @@ func TestGetPodGroupMetadata_Tekton_PipelineRun(t *testing.T) {
296297
}
297298
pod := &v1.Pod{}
298299

299-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
300+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
300301
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
301302

302303
assert.Nil(t, err)

0 commit comments

Comments
 (0)