Skip to content

Commit 3d961a4

Browse files
authored
cherry pick - validate priorityclass exists in podgrouper (#336)
1 parent 2cec1db commit 3d961a4

File tree

20 files changed

+383
-155
lines changed

20 files changed

+383
-155
lines changed

deployments/kai-scheduler/templates/rbac/podgrouper.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ rules:
239239
- create
240240
- patch
241241
- update
242+
- apiGroups:
243+
- scheduling.k8s.io
244+
resources:
245+
- priorityclasses
246+
verbs:
247+
- get
248+
- list
249+
- watch
242250
- apiGroups:
243251
- scheduling.run.ai
244252
resources:

pkg/podgrouper/pod_controller.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type Configs struct {
5757
// +kubebuilder:rbac:groups="",resources=configmaps,verbs=get;list;watch
5858
// +kubebuilder:rbac:groups="",resources=pods,verbs=create;update;patch
5959
// +kubebuilder:rbac:groups="",resources=events,verbs=create;patch;update;get;list;watch
60+
// +kubebuilder:rbac:groups="scheduling.k8s.io",resources=priorityclasses,verbs=get;list;watch
6061
// +kubebuilder:rbac:groups="scheduling.run.ai",resources=podgroups,verbs=create;update;patch;get;list;watch
6162

6263
// Reconcile is part of the main kubernetes reconciliation loop which aims to

pkg/podgrouper/podgrouper/hub/hub.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ func (ph *PluginsHub) GetPodGrouperPlugin(gvk metav1.GroupVersionKind) grouper.G
7575
func NewPluginsHub(kubeClient client.Client, searchForLegacyPodGroups,
7676
gangScheduleKnative bool, queueLabelKey, nodePoolLabelKey string,
7777
defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace string) *PluginsHub {
78-
defaultGrouper := defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
79-
defaultGrouper.SetDefaultPrioritiesConfigMapParams(defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace, kubeClient)
78+
defaultGrouper := defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, kubeClient)
79+
defaultGrouper.SetDefaultPrioritiesConfigMapParams(defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace)
8080

8181
kubeFlowDistributedGrouper := kubeflow.NewKubeflowDistributedGrouper(defaultGrouper)
8282
mpiGrouper := mpi.NewMpiGrouper(kubeClient, kubeFlowDistributedGrouper)

pkg/podgrouper/podgrouper/plugins/aml/aml_grouper_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/stretchr/testify/assert"
1010
v1 "k8s.io/api/core/v1"
1111
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
12+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
1213

1314
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/defaultgrouper"
1415
)
@@ -47,7 +48,7 @@ func TestGetPodGroupMetadata(t *testing.T) {
4748
}
4849
pod := &v1.Pod{}
4950

50-
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
51+
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient()))
5152
podGroupMetadata, err := amlGrouper.GetPodGroupMetadata(owner, pod)
5253

5354
assert.Nil(t, err)
@@ -89,7 +90,7 @@ func TestGetPodGroupMetadataWithoutReplicas(t *testing.T) {
8990
}
9091
pod := &v1.Pod{}
9192

92-
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
93+
amlGrouper := NewAmlGrouper(defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient()))
9394
_, err := amlGrouper.GetPodGroupMetadata(owner, pod)
9495

9596
assert.NotNil(t, err)

pkg/podgrouper/podgrouper/plugins/cronjobs/cronjob_grouper_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func TestGetPodGroupMetadata(t *testing.T) {
6767
assert.Nil(t, err)
6868

6969
client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(job).Build()
70-
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
70+
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, client))
7171
podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)
7272

7373
assert.Nil(t, err)
@@ -105,7 +105,7 @@ func TestGetPodGroupMetadataJobOwnerNotFound(t *testing.T) {
105105
}
106106

107107
client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(cronjob).Build()
108-
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
108+
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, client))
109109
podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)
110110

111111
assert.NotNil(t, err)
@@ -144,7 +144,7 @@ func TestGetPodGroupMetadataJobNotExists(t *testing.T) {
144144
assert.Nil(t, err)
145145

146146
client := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithRuntimeObjects(cronjob).Build()
147-
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey))
147+
grouper := NewCronJobGrouper(client, defaultgrouper.NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, client))
148148
podGroupMetadata, err := grouper.GetPodGroupMetadata(cronjob, pod)
149149
assert.NotNil(t, err)
150150
assert.Equal(t, fmt.Sprintf("jobs.batch \"%s\" not found", jobName), err.Error())

pkg/podgrouper/podgrouper/plugins/defaultgrouper/default_grouper.go

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"golang.org/x/exp/maps"
1212
v1 "k8s.io/api/core/v1"
13+
schedulingv1 "k8s.io/api/scheduling/v1"
1314
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1415
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
1516
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -36,19 +37,17 @@ type DefaultGrouper struct {
3637
kubeReader client.Reader
3738
}
3839

39-
func NewDefaultGrouper(queueLabelKey, nodePoolLabelKey string) *DefaultGrouper {
40+
func NewDefaultGrouper(queueLabelKey, nodePoolLabelKey string, kubeReader client.Reader) *DefaultGrouper {
4041
return &DefaultGrouper{
4142
queueLabelKey: queueLabelKey,
4243
nodePoolLabelKey: nodePoolLabelKey,
44+
kubeReader: kubeReader,
4345
}
4446
}
4547

46-
func (dg *DefaultGrouper) SetDefaultPrioritiesConfigMapParams(
47-
defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace string, kubeReader client.Reader,
48-
) {
48+
func (dg *DefaultGrouper) SetDefaultPrioritiesConfigMapParams(defaultPrioritiesConfigMapName, defaultPrioritiesConfigMapNamespace string) {
4949
dg.defaultPrioritiesConfigMapName = defaultPrioritiesConfigMapName
5050
dg.defaultPrioritiesConfigMapNamespace = defaultPrioritiesConfigMapNamespace
51-
dg.kubeReader = kubeReader
5251
}
5352

5453
func (dg *DefaultGrouper) Name() string {
@@ -151,29 +150,63 @@ func (dg *DefaultGrouper) calculateQueueName(topOwner *unstructured.Unstructured
151150

152151
func (dg *DefaultGrouper) CalcPodGroupPriorityClass(topOwner *unstructured.Unstructured, pod *v1.Pod,
153152
defaultPriorityClassForJob string) string {
153+
priorityClassName := dg.calcPodGroupPriorityClass(topOwner, pod)
154+
if dg.validatePriorityClassExists(priorityClassName) {
155+
return priorityClassName
156+
}
157+
158+
if priorityClassName != "" {
159+
logger.V(2).Info("priorityClassName from pod or owner labels is not valid, falling back to default",
160+
"priorityClassName", priorityClassName, "topOwner", topOwner.GetName(), "pod", pod.GetName())
161+
}
162+
163+
groupKind := topOwner.GroupVersionKind().GroupKind()
164+
priorityClassName = dg.getDefaultPriorityClassNameForKind(&groupKind)
165+
if dg.validatePriorityClassExists(priorityClassName) {
166+
return priorityClassName
167+
}
168+
169+
logger.V(2).Info("No default priority class found for group kind, using default fallback",
170+
"groupKind", groupKind.String(), "defaultFallback", defaultPriorityClassForJob)
171+
return defaultPriorityClassForJob
172+
}
173+
174+
func (dg *DefaultGrouper) calcPodGroupPriorityClass(topOwner *unstructured.Unstructured, pod *v1.Pod) string {
154175
if priorityClassName, found := topOwner.GetLabels()[constants.PriorityLabelKey]; found {
155176
return priorityClassName
156177
} else if priorityClassName, found = pod.GetLabels()[constants.PriorityLabelKey]; found {
157178
return priorityClassName
158179
} else if len(pod.Spec.PriorityClassName) != 0 {
159180
return pod.Spec.PriorityClassName
160-
} else {
161-
groupKind := topOwner.GroupVersionKind().GroupKind()
162-
return dg.getDefaultPriorityClassNameForKind(&groupKind, defaultPriorityClassForJob)
163181
}
182+
return ""
183+
}
184+
185+
func (dg *DefaultGrouper) validatePriorityClassExists(priorityClassName string) bool {
186+
if priorityClassName == "" || dg.kubeReader == nil {
187+
return false
188+
}
189+
190+
priorityClass := &schedulingv1.PriorityClass{}
191+
err := dg.kubeReader.Get(context.Background(), client.ObjectKey{Name: priorityClassName}, priorityClass)
192+
if err != nil {
193+
logger.V(1).Error(err, "Failed to get priority class", "priorityClassName", priorityClassName)
194+
return false
195+
}
196+
return true
164197
}
165198

166199
// getDefaultPriorityClassNameForKind - returns the default priority class name for a given group kind.
167-
func (dg *DefaultGrouper) getDefaultPriorityClassNameForKind(groupKind *schema.GroupKind, defaultPriorityClassFallback string) string {
200+
func (dg *DefaultGrouper) getDefaultPriorityClassNameForKind(groupKind *schema.GroupKind) string {
168201
if groupKind == nil || groupKind.String() == "" || groupKind.Kind == "" {
169202
logger.V(3).Info("Unable to get default priority class name: GroupKind is empty, using default priority class fallback")
170-
return defaultPriorityClassFallback
203+
return ""
171204
}
172205

173206
defaultPriorities, err := dg.getDefaultPrioritiesPerTypeMapping()
174207
if err != nil {
175208
logger.V(1).Error(err, "Unable to get default priorities mapping")
176-
return defaultPriorityClassFallback
209+
return ""
177210
}
178211

179212
// Check if the groupKind is in the default priorities map.
@@ -187,9 +220,7 @@ func (dg *DefaultGrouper) getDefaultPriorityClassNameForKind(groupKind *schema.G
187220
return priorityClassName
188221
}
189222

190-
logger.V(4).Info("No default priority class found for group kind, using default fallback",
191-
"groupKind", groupKind.String(), "defaultFallback", defaultPriorityClassFallback)
192-
return defaultPriorityClassFallback
223+
return ""
193224
}
194225

195226
// getDefaultPrioritiesPerTypeMapping - returns a map of workload type to default priority class name.

pkg/podgrouper/podgrouper/plugins/defaultgrouper/default_grouper_owners_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/stretchr/testify/assert"
1010
v1 "k8s.io/api/core/v1"
1111
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
12+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
1213
)
1314

1415
const (
@@ -57,7 +58,7 @@ func TestGetPodGroupMetadata_KubeflowPipelineScheduledWorkflow(t *testing.T) {
5758
}
5859
pod := &v1.Pod{}
5960

60-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
61+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
6162
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
6263

6364
assert.Nil(t, err)
@@ -129,7 +130,7 @@ func TestGetPodGroupMetadata_ArgoWorkflow(t *testing.T) {
129130
}
130131
pod := &v1.Pod{}
131132

132-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
133+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
133134
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
134135

135136
assert.Nil(t, err)
@@ -198,7 +199,7 @@ func TestGetPodGroupMetadata_Tekton_TaskRun(t *testing.T) {
198199
}
199200
pod := &v1.Pod{}
200201

201-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
202+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
202203
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
203204

204205
assert.Nil(t, err)
@@ -296,7 +297,7 @@ func TestGetPodGroupMetadata_Tekton_PipelineRun(t *testing.T) {
296297
}
297298
pod := &v1.Pod{}
298299

299-
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey)
300+
defaultGrouper := NewDefaultGrouper(queueLabelKey, nodePoolLabelKey, fake.NewFakeClient())
300301
podGroupMetadata, err := defaultGrouper.GetPodGroupMetadata(owner, pod)
301302

302303
assert.Nil(t, err)

0 commit comments

Comments
 (0)