Skip to content

Commit 5938926

Browse files
authored
Changed SubSetNodes signature to use SubGroupInfo instead of SubGroupSet (#561)
1 parent 781fe28 commit 5938926

File tree

10 files changed: 56 additions and 42 deletions.

pkg/scheduler/actions/common/allocate.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@ func AllocateJob(ssn *framework.Session, stmt *framework.Statement, nodes []*nod
2727
return false
2828
}
2929

30-
nodeSets, err := ssn.SubsetNodesFn(job, job.RootSubGroupSet, tasksToAllocate, nodes)
30+
podSets := job.RootSubGroupSet.GetAllPodSets()
31+
nodeSets, err := ssn.SubsetNodesFn(job, &job.RootSubGroupSet.SubGroupInfo, podSets, tasksToAllocate, nodes)
3132
if err != nil {
3233
log.InfraLogger.Errorf(
3334
"Failed to run SubsetNodes on job <%s/%s>: %v", job.Namespace, job.Namespace, err)

pkg/scheduler/api/podgroup_info/job_info.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ func NewPodGroupInfo(uid common_info.PodGroupID, tasks ...*pod_info.PodInfo) *Po
116116
Stale: false,
117117
},
118118
RootSubGroupSet: defaultSubGroupSet,
119-
PodSets: defaultSubGroupSet.GetPodSets(),
119+
PodSets: defaultSubGroupSet.GetAllPodSets(),
120120

121121
LastStartTimestamp: nil,
122122
activeAllocatedCount: ptr.To(0),
@@ -199,7 +199,7 @@ func (pgi *PodGroupInfo) setSubGroups(podGroup *enginev2alpha2.PodGroup) error {
199199
return err
200200
}
201201
pgi.RootSubGroupSet = rootSubGroupSet
202-
podSets := rootSubGroupSet.GetPodSets()
202+
podSets := rootSubGroupSet.GetAllPodSets()
203203
if len(podSets) > 0 {
204204
pgi.PodSets = podSets
205205
} else {
@@ -480,7 +480,7 @@ func (pgi *PodGroupInfo) CloneWithTasks(tasks []*pod_info.PodInfo) *PodGroupInfo
480480
pgi.CreationTimestamp.DeepCopyInto(&info.CreationTimestamp)
481481

482482
info.RootSubGroupSet = pgi.RootSubGroupSet.Clone()
483-
info.PodSets = info.RootSubGroupSet.GetPodSets()
483+
info.PodSets = info.RootSubGroupSet.GetAllPodSets()
484484

485485
for _, task := range tasks {
486486
info.AddTaskInfo(task.Clone())

pkg/scheduler/api/podgroup_info/subgroup_info/subgroupset.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -63,13 +63,13 @@ func (sgs *SubGroupSet) Clone() *SubGroupSet {
6363
return root
6464
}
6565

66-
func (sgs *SubGroupSet) GetPodSets() map[string]*PodSet {
66+
func (sgs *SubGroupSet) GetAllPodSets() map[string]*PodSet {
6767
result := make(map[string]*PodSet)
6868
for _, podSet := range sgs.podSets {
6969
result[podSet.GetName()] = podSet
7070
}
7171
for _, subGroup := range sgs.groups {
72-
podSets := subGroup.GetPodSets()
72+
podSets := subGroup.GetAllPodSets()
7373
for name, podSet := range podSets {
7474
result[name] = podSet
7575
}

pkg/scheduler/api/podgroup_info/subgroup_info/subgroupset_test.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ func TestGetPodSets(t *testing.T) {
175175
for _, tt := range tests {
176176
t.Run(tt.name, func(t *testing.T) {
177177
sg, want := tt.build()
178-
got := sg.GetPodSets()
178+
got := sg.GetAllPodSets()
179179
if len(got) != len(want) {
180180
t.Errorf("expected %d podsets, got %d", len(want), len(got))
181181
}

pkg/scheduler/api/types.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,8 @@ import (
3030
)
3131

3232
// SubsetNodesFn is used to divide the nodes into sets
33-
type SubsetNodesFn func(podGroup *podgroup_info.PodGroupInfo, subGroupSet *subgroup_info.SubGroupSet, tasks []*pod_info.PodInfo, nodeSet node_info.NodeSet) ([]node_info.NodeSet, error)
33+
type SubsetNodesFn func(podGroup *podgroup_info.PodGroupInfo, subGroup *subgroup_info.SubGroupInfo,
34+
podSets map[string]*subgroup_info.PodSet, tasks []*pod_info.PodInfo, nodeSet node_info.NodeSet) ([]node_info.NodeSet, error)
3435

3536
// PredicateFn is used to predicate node for task.
3637
type PredicateFn func(*pod_info.PodInfo, *podgroup_info.PodGroupInfo, *node_info.NodeInfo) error

pkg/scheduler/framework/session_plugins.go

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,9 @@ func (ssn *Session) SubGroupOrderFn(l, r interface{}) bool {
264264
return lSubGroup.GetName() < rSubGroup.GetName()
265265
}
266266

267-
func (ssn *Session) QueueOrderFn(lQ, rQ *queue_info.QueueInfo, lJob, rJob *podgroup_info.PodGroupInfo, lVictims, rVictims []*podgroup_info.PodGroupInfo) bool {
267+
func (ssn *Session) QueueOrderFn(lQ, rQ *queue_info.QueueInfo, lJob, rJob *podgroup_info.PodGroupInfo,
268+
lVictims, rVictims []*podgroup_info.PodGroupInfo,
269+
) bool {
268270
for _, qof := range ssn.QueueOrderFns {
269271
if j := qof(lQ, rQ, lJob, rJob, lVictims, rVictims); j != 0 {
270272
return j < 0
@@ -322,14 +324,17 @@ func (ssn *Session) IsTaskAllocationOnNodeOverCapacityFn(task *pod_info.PodInfo,
322324
}
323325
}
324326

325-
func (ssn *Session) SubsetNodesFn(podGroup *podgroup_info.PodGroupInfo, subGroupSet *subgroup_info.SubGroupSet, tasks []*pod_info.PodInfo, initNodeSet node_info.NodeSet) ([]node_info.NodeSet, error) {
327+
func (ssn *Session) SubsetNodesFn(
328+
podGroup *podgroup_info.PodGroupInfo, subGroupInfo *subgroup_info.SubGroupInfo,
329+
podSets map[string]*subgroup_info.PodSet, tasks []*pod_info.PodInfo, initNodeSet node_info.NodeSet,
330+
) ([]node_info.NodeSet, error) {
326331
nodeSets := []node_info.NodeSet{initNodeSet}
327332
for _, subsetNodesFn := range ssn.SubsetNodesFns {
328333
log.InfraLogger.V(7).Infof(
329334
"Running plugin func <%v> on podGroup <%s/%s>", subsetNodesFn, podGroup.Namespace, podGroup.Namespace)
330335
var newNodeSets []node_info.NodeSet
331336
for _, nodeSet := range nodeSets {
332-
nodeSubsets, err := subsetNodesFn(podGroup, subGroupSet, tasks, nodeSet)
337+
nodeSubsets, err := subsetNodesFn(podGroup, subGroupInfo, podSets, tasks, nodeSet)
333338
if err != nil {
334339
return nil, err
335340
}

pkg/scheduler/framework/session_plugins_test.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@ func TestPartitionMultiImplementation(t *testing.T) {
113113
},
114114
}
115115

116-
shardClusterSubseting := func(_ *podgroup_info.PodGroupInfo, _ *subgroup_info.SubGroupSet, _ []*pod_info.PodInfo, nodeset node_info.NodeSet) ([]node_info.NodeSet, error) {
116+
shardClusterSubseting := func(_ *podgroup_info.PodGroupInfo, _ *subgroup_info.SubGroupInfo, _ map[string]*subgroup_info.PodSet, _ []*pod_info.PodInfo, nodeset node_info.NodeSet) ([]node_info.NodeSet, error) {
117117
var subset1 []*node_info.NodeInfo
118118
var subset2 []*node_info.NodeInfo
119119
for _, node := range nodeset {
@@ -126,7 +126,7 @@ func TestPartitionMultiImplementation(t *testing.T) {
126126
return []node_info.NodeSet{subset1, subset2}, nil
127127
}
128128

129-
topologySubseting := func(_ *podgroup_info.PodGroupInfo, _ *subgroup_info.SubGroupSet, _ []*pod_info.PodInfo, nodeset node_info.NodeSet) ([]node_info.NodeSet, error) {
129+
topologySubseting := func(_ *podgroup_info.PodGroupInfo, _ *subgroup_info.SubGroupInfo, _ map[string]*subgroup_info.PodSet, _ []*pod_info.PodInfo, nodeset node_info.NodeSet) ([]node_info.NodeSet, error) {
130130
var subset1 []*node_info.NodeInfo
131131
var subset2 []*node_info.NodeInfo
132132
for _, node := range nodeset {
@@ -144,7 +144,7 @@ func TestPartitionMultiImplementation(t *testing.T) {
144144
ssn.AddSubsetNodesFn(shardClusterSubseting)
145145
ssn.AddSubsetNodesFn(topologySubseting)
146146

147-
partitions, _ := ssn.SubsetNodesFn(podgroup_info.NewPodGroupInfo("a"), nil, nil, nodes)
147+
partitions, _ := ssn.SubsetNodesFn(podgroup_info.NewPodGroupInfo("a"), nil, nil, nil, nodes)
148148

149149
assert.Equal(t, 4, len(partitions))
150150

pkg/scheduler/plugins/topology/job_filtering.go

Lines changed: 25 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,15 @@ type jobAllocationMetaData struct {
2121
}
2222

2323
func (t *topologyPlugin) subSetNodesFn(
24-
job *podgroup_info.PodGroupInfo, subGroupSet *subgroup_info.SubGroupSet, tasks []*pod_info.PodInfo, nodeSet node_info.NodeSet,
24+
job *podgroup_info.PodGroupInfo, subGroup *subgroup_info.SubGroupInfo, podSets map[string]*subgroup_info.PodSet,
25+
tasks []*pod_info.PodInfo, nodeSet node_info.NodeSet,
2526
) ([]node_info.NodeSet, error) {
26-
topologyTree, found := t.getJobTopology(subGroupSet)
27+
topologyTree, found := t.getJobTopology(subGroup)
2728
if !found {
2829
job.SetJobFitError(
2930
podgroup_info.PodSchedulingErrors,
3031
fmt.Sprintf("Matching topology %s does not exist",
31-
subGroupSet.GetTopologyConstraint().Topology),
32+
subGroup.GetTopologyConstraint().Topology),
3233
nil)
3334
return []node_info.NodeSet{}, nil
3435
}
@@ -50,7 +51,7 @@ func (t *topologyPlugin) subSetNodesFn(
5051
return []node_info.NodeSet{}, nil
5152
}
5253

53-
jobAllocatableDomains, err := t.getJobAllocatableDomains(job, subGroupSet, len(tasks), topologyTree)
54+
jobAllocatableDomains, err := t.getJobAllocatableDomains(job, subGroup, podSets, len(tasks), topologyTree)
5455
if err != nil {
5556
return nil, err
5657
}
@@ -67,11 +68,11 @@ func (t *topologyPlugin) subSetNodesFn(
6768
return domainNodeSets, nil
6869
}
6970

70-
func (t *topologyPlugin) getJobTopology(subGroupSet *subgroup_info.SubGroupSet) (*Info, bool) {
71-
if subGroupSet.GetTopologyConstraint() == nil {
71+
func (t *topologyPlugin) getJobTopology(subGroup *subgroup_info.SubGroupInfo) (*Info, bool) {
72+
if subGroup.GetTopologyConstraint() == nil {
7273
return nil, true
7374
}
74-
jobTopologyName := subGroupSet.GetTopologyConstraint().Topology
75+
jobTopologyName := subGroup.GetTopologyConstraint().Topology
7576
if jobTopologyName == "" {
7677
return nil, true
7778
}
@@ -188,18 +189,19 @@ func calcNextAllocationTestPodResources(previousTestResources, maxPodResources *
188189
}
189190

190191
func (t *topologyPlugin) getJobAllocatableDomains(
191-
job *podgroup_info.PodGroupInfo, subGroupSet *subgroup_info.SubGroupSet, taskToAllocateCount int, topologyTree *Info,
192+
job *podgroup_info.PodGroupInfo, subGroup *subgroup_info.SubGroupInfo, podSets map[string]*subgroup_info.PodSet,
193+
taskToAllocateCount int, topologyTree *Info,
192194
) ([]*DomainInfo, error) {
193-
relevantLevels, err := t.calculateRelevantDomainLevels(subGroupSet, topologyTree)
195+
relevantLevels, err := t.calculateRelevantDomainLevels(subGroup, topologyTree)
194196
if err != nil {
195197
return nil, err
196198
}
197199

198200
// Validate that the domains do not clash with the chosen domain for active pods of the job
199201
var relevantDomainsByLevel domainsByLevel
200-
if hasActiveAllocatedTasks(subGroupSet) && hasTopologyRequiredConstraint(subGroupSet) {
201-
relevantDomainsByLevel = getRelevantDomainsWithAllocatedPods(subGroupSet, topologyTree,
202-
DomainLevel(subGroupSet.GetTopologyConstraint().RequiredLevel))
202+
if hasActiveAllocatedTasks(podSets) && hasTopologyRequiredConstraint(subGroup) {
203+
relevantDomainsByLevel = getRelevantDomainsWithAllocatedPods(podSets, topologyTree,
204+
DomainLevel(subGroup.GetTopologyConstraint().RequiredLevel))
203205
} else {
204206
relevantDomainsByLevel = topologyTree.DomainsByLevel
205207
}
@@ -223,8 +225,8 @@ func (t *topologyPlugin) getJobAllocatableDomains(
223225
return domains, nil
224226
}
225227

226-
func hasActiveAllocatedTasks(subGroupSet *subgroup_info.SubGroupSet) bool {
227-
for _, podSet := range subGroupSet.GetPodSets() {
228+
func hasActiveAllocatedTasks(podSets map[string]*subgroup_info.PodSet) bool {
229+
for _, podSet := range podSets {
228230
if podSet.GetNumActiveAllocatedTasks() > 0 {
229231
return true
230232
}
@@ -233,20 +235,20 @@ func hasActiveAllocatedTasks(subGroupSet *subgroup_info.SubGroupSet) bool {
233235
}
234236

235237
func getRelevantDomainsWithAllocatedPods(
236-
subGroupSet *subgroup_info.SubGroupSet, topologyTree *Info, requiredLevel DomainLevel,
238+
podSets map[string]*subgroup_info.PodSet, topologyTree *Info, requiredLevel DomainLevel,
237239
) domainsByLevel {
238240
relevantDomainsByLevel := domainsByLevel{}
239241
for _, domainAtRequiredLevel := range topologyTree.DomainsByLevel[requiredLevel] {
240-
if !hasActiveJobPodInDomain(subGroupSet, domainAtRequiredLevel) {
242+
if !hasActiveJobPodInDomain(podSets, domainAtRequiredLevel) {
241243
continue // if the domain at the top level does not have any active pods, then any domains under the subtree cannot satisfy the required constraint for both active and pending pods
242244
}
243245
addSubTreeToDomainMap(domainAtRequiredLevel, relevantDomainsByLevel)
244246
}
245247
return relevantDomainsByLevel
246248
}
247249

248-
func hasActiveJobPodInDomain(subGroupSet *subgroup_info.SubGroupSet, domain *DomainInfo) bool {
249-
for _, podSet := range subGroupSet.GetPodSets() {
250+
func hasActiveJobPodInDomain(podSets map[string]*subgroup_info.PodSet, domain *DomainInfo) bool {
251+
for _, podSet := range podSets {
250252
for _, pod := range podSet.GetPodInfos() {
251253
if pod_status.IsActiveAllocatedStatus(pod.Status) {
252254
podInDomain := domain.Nodes[pod.NodeName] != nil
@@ -269,19 +271,19 @@ func addSubTreeToDomainMap(domain *DomainInfo, domainsMap domainsByLevel) {
269271
domainsMap[domain.Level][domain.ID] = domain
270272
}
271273

272-
func hasTopologyRequiredConstraint(subGroupSet *subgroup_info.SubGroupSet) bool {
273-
return subGroupSet.GetTopologyConstraint().RequiredLevel != ""
274+
func hasTopologyRequiredConstraint(subGroup *subgroup_info.SubGroupInfo) bool {
275+
return subGroup.GetTopologyConstraint().RequiredLevel != ""
274276
}
275277

276278
func (*topologyPlugin) calculateRelevantDomainLevels(
277-
subGroupSet *subgroup_info.SubGroupSet, topologyTree *Info,
279+
subGroup *subgroup_info.SubGroupInfo, topologyTree *Info,
278280
) ([]DomainLevel, error) {
279-
topologyConstraint := subGroupSet.GetTopologyConstraint()
281+
topologyConstraint := subGroup.GetTopologyConstraint()
280282
requiredPlacement := DomainLevel(topologyConstraint.RequiredLevel)
281283
preferredPlacement := DomainLevel(topologyConstraint.PreferredLevel)
282284
if requiredPlacement == "" && preferredPlacement == "" {
283285
return nil, fmt.Errorf("no topology constraints were found for subgroup %s, with topology name %s",
284-
subGroupSet.GetName(), topologyTree.Name)
286+
subGroup.GetName(), topologyTree.Name)
285287
}
286288

287289
foundRequiredLevel := false

pkg/scheduler/plugins/topology/job_filtering_test.go

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -390,7 +390,9 @@ func TestTopologyPlugin_subsetNodesFn(t *testing.T) {
390390
}
391391

392392
// Call the function under test
393-
subsets, err := plugin.subSetNodesFn(job, job.RootSubGroupSet, podgroup_info.GetTasksToAllocate(job, nil, nil, true), maps.Values(nodesInfoMap))
393+
subsets, err := plugin.subSetNodesFn(job, &job.RootSubGroupSet.SubGroupInfo,
394+
job.RootSubGroupSet.GetAllPodSets(), podgroup_info.GetTasksToAllocate(job, nil, nil, true),
395+
maps.Values(nodesInfoMap))
394396

395397
// Check error
396398
if tt.expectedError != "" {
@@ -723,7 +725,7 @@ func TestTopologyPlugin_calculateRelevantDomainLevels(t *testing.T) {
723725
t.Run(tt.name, func(t *testing.T) {
724726
plugin := &topologyPlugin{}
725727

726-
result, err := plugin.calculateRelevantDomainLevels(tt.subGroupSet, tt.topologyTree)
728+
result, err := plugin.calculateRelevantDomainLevels(&tt.subGroupSet.SubGroupInfo, tt.topologyTree)
727729

728730
// Check error
729731
if tt.expectedError != "" {
@@ -1081,7 +1083,7 @@ func TestTopologyPlugin_calcTreeAllocatable(t *testing.T) {
10811083
},
10821084
},
10831085
{
1084-
name: "no leaf domains - no allocateable domains",
1086+
name: "no leaf domains - no allocatable domains",
10851087
job: &jobs_fake.TestJobBasic{
10861088
Name: "test-job",
10871089
RequiredCPUsPerTask: 2000, // Too much for any node
@@ -1840,7 +1842,9 @@ func TestTopologyPlugin_getJobAllocatableDomains(t *testing.T) {
18401842
for _, podSet := range tt.job.PodSets {
18411843
tt.job.RootSubGroupSet.AddPodSet(podSet)
18421844
}
1843-
result, err := plugin.getJobAllocatableDomains(tt.job, tt.job.RootSubGroupSet, len(podgroup_info.GetTasksToAllocate(tt.job, nil, nil, true)), tt.topologyTree)
1845+
result, err := plugin.getJobAllocatableDomains(tt.job, &tt.job.RootSubGroupSet.SubGroupInfo,
1846+
tt.job.RootSubGroupSet.GetAllPodSets(), len(podgroup_info.GetTasksToAllocate(tt.job, nil, nil, true)),
1847+
tt.topologyTree)
18441848

18451849
// Check error
18461850
if tt.expectedError != "" {

pkg/scheduler/test_utils/test_utils.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@ import (
1818
"k8s.io/client-go/informers"
1919
"k8s.io/client-go/kubernetes/fake"
2020

21+
kueuev1alpha1 "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
22+
2123
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/actions"
2224
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/common_info"
2325
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_status"
@@ -31,7 +33,6 @@ import (
3133
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/test_utils/dra_fake"
3234
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/test_utils/jobs_fake"
3335
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/test_utils/nodes_fake"
34-
kueuev1alpha1 "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
3536
)
3637

3738
var SchedulerVerbosity = flag.String("vv", "", "Scheduler's verbosity")

0 commit comments

Comments
 (0)