Skip to content

Commit 448dcc4

Browse files
authored
predicate and nodeOrder for the topology plugin + topology result caching (#356)
1 parent da36a65 commit 448dcc4

File tree

9 files changed

+145
-23
lines changed

9 files changed

+145
-23
lines changed

pkg/scheduler/actions/common/allocate.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ func allocateTask(ssn *framework.Session, stmt *framework.Statement, nodes []*no
7373
task.Namespace, task.Name, node.Name)
7474
}
7575

76+
ssn.CleanAllocationAttemptCache(job)
77+
7678
if success {
7779
log.InfraLogger.V(6).Infof("Allocation succeeded for task: <%v/%v>", task.Namespace, task.Name)
7880
} else {

pkg/scheduler/api/types.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ type OnJobSolutionStartFn func()
6868
// BindRequestMutateFn allows plugins to add annotations before BindRequest creation.
6969
type BindRequestMutateFn func(pod *pod_info.PodInfo, nodeName string) map[string]string
7070

71+
// CleanAllocationAttemptCacheFn is used to clean the cycle cache after an allocation attempt for a job.
72+
type CleanAllocationAttemptCacheFn func(job *podgroup_info.PodGroupInfo) error
73+
7174
type SchedulableResult struct {
7275
IsSchedulable bool
7376
Reason v2alpha2.UnschedulableReason

pkg/scheduler/framework/session.go

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,30 +81,24 @@ type Session struct {
8181
PrePredicateFns []api.PrePredicateFn
8282
PredicateFns []api.PredicateFn
8383
BindRequestMutateFns []api.BindRequestMutateFn
84+
CleanAllocationAttemptCacheFns []api.CleanAllocationAttemptCacheFn
8485

8586
Config *conf.SchedulerConfiguration
8687
plugins map[string]Plugin
8788
eventHandlers []*EventHandler
8889
SchedulerParams conf.SchedulerParams
8990
mux *http.ServeMux
9091

91-
k8sPodState map[types.UID]k8s_internal.SessionState
92+
k8sResourceStateCache sync.Map
9293
}
9394

9495
func (ssn *Session) Statement() *Statement {
9596
return &Statement{ssn: ssn, sessionUID: ssn.UID}
9697
}
9798

98-
func (ssn *Session) GetK8sStateForPod(uid types.UID) k8s_internal.SessionState {
99-
if ssn.k8sPodState == nil {
100-
ssn.k8sPodState = make(map[types.UID]k8s_internal.SessionState)
101-
}
102-
state, found := ssn.k8sPodState[uid]
103-
if found {
104-
return state
105-
}
106-
ssn.k8sPodState[uid] = k8s_internal.NewSessionState()
107-
return ssn.k8sPodState[uid]
99+
func (ssn *Session) GetSessionStateForResource(uid types.UID) k8s_internal.SessionState {
100+
state, _ := ssn.k8sResourceStateCache.LoadOrStore(uid, k8s_internal.NewSessionState())
101+
return state.(k8s_internal.SessionState)
108102
}
109103

110104
func (ssn *Session) BindPod(pod *pod_info.PodInfo) error {
@@ -346,10 +340,10 @@ func openSession(cache cache.Cache, sessionId types.UID, schedulerParams conf.Sc
346340
Queues: map[common_info.QueueID]*queue_info.QueueInfo{},
347341
Topologies: []*kueuev1alpha1.Topology{},
348342

349-
plugins: map[string]Plugin{},
350-
SchedulerParams: schedulerParams,
351-
mux: mux,
352-
k8sPodState: map[types.UID]k8s_internal.SessionState{},
343+
plugins: map[string]Plugin{},
344+
SchedulerParams: schedulerParams,
345+
mux: mux,
346+
k8sResourceStateCache: sync.Map{},
353347
}
354348

355349
log.InfraLogger.V(2).Infof("Taking cluster snapshot ...")

pkg/scheduler/framework/session_plugins.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,20 @@ func (ssn *Session) NodeOrderFn(task *pod_info.PodInfo, node *node_info.NodeInfo
375375
return priorityScore, nil
376376
}
377377

378+
func (ssn *Session) AddCleanAllocationAttemptCacheFn(fn api.CleanAllocationAttemptCacheFn) {
379+
ssn.CleanAllocationAttemptCacheFns = append(ssn.CleanAllocationAttemptCacheFns, fn)
380+
}
381+
382+
func (ssn *Session) CleanAllocationAttemptCache(job *podgroup_info.PodGroupInfo) {
383+
for _, cleaner := range ssn.CleanAllocationAttemptCacheFns {
384+
err := cleaner(job)
385+
if err != nil {
386+
log.InfraLogger.V(6).Infof(
387+
"Failed to run CleanAllocationAttemptCache on job %s", job.Name)
388+
}
389+
}
390+
}
391+
378392
func (ssn *Session) IsRestrictNodeSchedulingEnabled() bool {
379393
return ssn.SchedulerParams.RestrictSchedulingNodes
380394
}

pkg/scheduler/k8s_internal/k8s_internal.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func FitPrePredicateConverter(
5252
nodePreFilter NodePreFilter,
5353
) FitPredicatePreFilter {
5454
return func(pod *v1.Pod) (sets.Set[string], *k8sframework.Status) {
55-
state := stateProvider.GetK8sStateForPod(pod.UID)
55+
state := stateProvider.GetSessionStateForResource(pod.UID)
5656
result, status := nodePreFilter.PreFilter(context.Background(), state, pod)
5757
if status != nil {
5858
return nil, status
@@ -69,7 +69,7 @@ func FitPredicateConverter(
6969
nodeFilter NodeFilter,
7070
) FitPredicateFilter {
7171
return func(pod *v1.Pod, nodeInfo *k8sframework.NodeInfo) (bool, []string, error) {
72-
state := stateProvider.GetK8sStateForPod(pod.UID)
72+
state := stateProvider.GetSessionStateForResource(pod.UID)
7373
result := nodeFilter.Filter(context.Background(), state, pod, nodeInfo)
7474
if result == nil {
7575
return true, nil, nil
@@ -83,7 +83,7 @@ func PreScorePluginConverter(
8383
nodeScorer ExtendedNodeScorer,
8484
) PreScoreFn {
8585
return func(pod *v1.Pod, fittingNodes []*k8sframework.NodeInfo) *k8sframework.Status {
86-
state := stateProvider.GetK8sStateForPod(pod.UID)
86+
state := stateProvider.GetSessionStateForResource(pod.UID)
8787
status := nodeScorer.PreScore(context.Background(), state, pod, fittingNodes)
8888
return status
8989
}
@@ -94,7 +94,7 @@ func ScorePluginConverter(
9494
nodeScorer NodeScorer,
9595
) ScorePredicate {
9696
return func(pod *v1.Pod, nodeInfo *k8sframework.NodeInfo) (int64, []string, error) {
97-
state := stateProvider.GetK8sStateForPod(pod.UID)
97+
state := stateProvider.GetSessionStateForResource(pod.UID)
9898

9999
score, result := nodeScorer.Score(context.Background(), state, pod, nodeInfo.Node().Name)
100100
if result == nil {

pkg/scheduler/k8s_internal/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,5 @@ type SessionScoreFns struct {
5858
type SessionState *k8sframework.CycleState
5959

6060
type SessionStateProvider interface {
61-
GetK8sStateForPod(podUID types.UID) SessionState
61+
GetSessionStateForResource(podUID types.UID) SessionState
6262
}

pkg/scheduler/plugins/scores/scores.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ const (
88
ResourceType = 10
99
Availability = 100
1010
GpuSharing = 1000
11-
K8sPlugins = 10000
12-
NominatedNode = 100000
11+
Topology = 10000
12+
K8sPlugins = 100000
13+
NominatedNode = 1000000
1314
)

pkg/scheduler/plugins/topology/topology_plugin.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ func (t *topologyPlugin) OnSessionOpen(ssn *framework.Session) {
4646

4747
//pre-predicate to generate the whole topology tree and store per workload
4848
ssn.AddPrePredicateFn(t.prePredicateFn)
49+
//predicate to filter nodes that are related to parts of the tree that cannot accommodate the workload - this is for "required" use only
50+
ssn.AddPredicateFn(t.predicateFn)
51+
//node order to sort the nodes according to topology nodes score - this is for "prefer" use only
52+
ssn.AddNodeOrderFn(t.nodeOrderFn)
53+
//clean cycle cache after an allocation attempt for a job
54+
ssn.AddCleanAllocationAttemptCacheFn(t.cleanAllocationAttemptCache)
4955
}
5056

5157
func (t *topologyPlugin) initializeTopologyTree(topologies []*kueuev1alpha1.Topology, ssn *framework.Session) {

pkg/scheduler/plugins/topology/topology_plugin_job_filtering.go

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,32 @@
44
package topology
55

66
import (
7+
"errors"
78
"fmt"
89
"slices"
910
"sort"
11+
"strings"
1012

1113
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
1214
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
1315
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
1416
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/resource_info"
1517
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/log"
18+
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/plugins/scores"
19+
"k8s.io/apimachinery/pkg/types"
20+
k8sframework "k8s.io/kubernetes/pkg/scheduler/framework"
1621
)
1722

23+
type topologyStateData struct {
24+
relevantDomains []*TopologyDomainInfo
25+
}
26+
27+
func (t *topologyStateData) Clone() k8sframework.StateData {
28+
return &topologyStateData{
29+
relevantDomains: t.relevantDomains,
30+
}
31+
}
32+
1833
type jobAllocationMetaData struct {
1934
maxPodResources *resource_info.ResourceRequirements
2035
allocationTestPods []*pod_info.PodInfo
@@ -30,6 +45,16 @@ func (t *topologyPlugin) prePredicateFn(_ *pod_info.PodInfo, job *podgroup_info.
3045
return nil
3146
}
3247

48+
// Check in cache if the job has already been allocated to a domain
49+
jobAllocatableDomains, err := t.loadAllocatableDomainsFromCache(types.UID(job.PodGroupUID))
50+
if err != nil {
51+
return err
52+
}
53+
if len(jobAllocatableDomains) > 0 {
54+
// Cache is already populated, no need to calculate anything
55+
return nil
56+
}
57+
3358
defer t.treeAllocatableCleanup(topologyTree)
3459
maxAllocatablePods, err := t.calcTreeAllocatable(job, topologyTree)
3560
if err != nil {
@@ -42,11 +67,18 @@ func (t *topologyPlugin) prePredicateFn(_ *pod_info.PodInfo, job *podgroup_info.
4267
return nil
4368
}
4469

45-
_, err = t.getBestJobAllocatableDomains(job, topologyTree)
70+
jobAllocatableDomain, err := t.getBestJobAllocatableDomains(job, topologyTree)
4671
if err != nil {
4772
return err
4873
}
4974

75+
//Save results to cycle cache
76+
cycleJobState := (*k8sframework.CycleState)(t.sessionStateGetter.GetSessionStateForResource(job.PodGroupUID))
77+
cycleJobState.Write(
78+
k8sframework.StateKey(topologyPluginName),
79+
&topologyStateData{relevantDomains: jobAllocatableDomain},
80+
)
81+
5082
return nil
5183
}
5284

@@ -278,3 +310,73 @@ func (*topologyPlugin) treeAllocatableCleanup(topologyTree *TopologyInfo) {
278310
}
279311
}
280312
}
313+
314+
func (t *topologyPlugin) predicateFn(pod *pod_info.PodInfo, job *podgroup_info.PodGroupInfo, node *node_info.NodeInfo) error {
315+
jobAllocatableDomains, err := t.loadAllocatableDomainsFromCache(job.PodGroupUID)
316+
if err != nil {
317+
return err
318+
}
319+
320+
if len(jobAllocatableDomains) > 0 {
321+
jobDomainsNames := []string{}
322+
for _, domain := range jobAllocatableDomains {
323+
if domain.Nodes[node.Node.Name] != nil {
324+
return nil
325+
}
326+
jobDomainsNames = append(jobDomainsNames, domain.Name)
327+
}
328+
return fmt.Errorf("the node %s is not part of the chosen topology domain for the job %s. The chosen domains are %s",
329+
node.Node.Name, job.PodGroup.Name, strings.Join(jobDomainsNames, ", "))
330+
}
331+
332+
return nil
333+
}
334+
335+
func (t *topologyPlugin) nodeOrderFn(pod *pod_info.PodInfo, node *node_info.NodeInfo) (float64, error) {
336+
score := 0.0
337+
338+
jobAllocatableDomains, err := t.loadAllocatableDomainsFromCache(types.UID(pod.Job))
339+
if err != nil {
340+
return score, err
341+
}
342+
343+
if len(jobAllocatableDomains) > 0 {
344+
for _, domain := range jobAllocatableDomains {
345+
if domain.Nodes[node.Node.Name] != nil {
346+
score = scores.Topology
347+
break
348+
}
349+
}
350+
}
351+
352+
return score, nil
353+
}
354+
355+
func (t *topologyPlugin) loadAllocatableDomainsFromCache(podGroupUID types.UID) ([]*TopologyDomainInfo, error) {
356+
cycleJobState := (*k8sframework.CycleState)(t.sessionStateGetter.GetSessionStateForResource(podGroupUID))
357+
if cycleJobState == nil {
358+
return nil, nil
359+
}
360+
jobTopologyStateData, err := cycleJobState.Read(k8sframework.StateKey(topologyPluginName))
361+
if err != nil {
362+
if errors.Is(err, k8sframework.ErrNotFound) {
363+
return nil, nil
364+
}
365+
return nil, err
366+
}
367+
jobAllocatableDomains := jobTopologyStateData.(*topologyStateData).relevantDomains
368+
return jobAllocatableDomains, nil
369+
}
370+
371+
func (t *topologyPlugin) cleanAllocationAttemptCache(job *podgroup_info.PodGroupInfo) error {
372+
if job.PodGroup.Spec.TopologyConstraint.Topology == "" {
373+
return nil
374+
}
375+
376+
cycleJobState := (*k8sframework.CycleState)(t.sessionStateGetter.GetSessionStateForResource(job.PodGroupUID))
377+
if cycleJobState == nil {
378+
return nil
379+
}
380+
cycleJobState.Delete(k8sframework.StateKey(topologyPluginName))
381+
return nil
382+
}

0 commit comments

Comments
 (0)