44package topology
55
66import (
7+ "errors"
78 "fmt"
89 "slices"
910 "sort"
11+ "strings"
1012
1113 "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
1214 "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
1315 "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
1416 "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/resource_info"
1517 "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/log"
18+ "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/plugins/scores"
19+ "k8s.io/apimachinery/pkg/types"
20+ k8sframework "k8s.io/kubernetes/pkg/scheduler/framework"
1621)
1722
23+ type topologyStateData struct {
24+ relevantDomains []* TopologyDomainInfo
25+ }
26+
27+ func (t * topologyStateData ) Clone () k8sframework.StateData {
28+ return & topologyStateData {
29+ relevantDomains : t .relevantDomains ,
30+ }
31+ }
32+
1833type jobAllocationMetaData struct {
1934 maxPodResources * resource_info.ResourceRequirements
2035 allocationTestPods []* pod_info.PodInfo
@@ -30,6 +45,16 @@ func (t *topologyPlugin) prePredicateFn(_ *pod_info.PodInfo, job *podgroup_info.
3045 return nil
3146 }
3247
48+ // Check in cache if the job has already been allocated to a domain
49+ jobAllocatableDomains , err := t .loadAllocatableDomainsFromCache (types .UID (job .PodGroupUID ))
50+ if err != nil {
51+ return err
52+ }
53+ if len (jobAllocatableDomains ) > 0 {
54+ // Cache is already populated, no need to calculate anything
55+ return nil
56+ }
57+
3358 defer t .treeAllocatableCleanup (topologyTree )
3459 maxAllocatablePods , err := t .calcTreeAllocatable (job , topologyTree )
3560 if err != nil {
@@ -42,11 +67,18 @@ func (t *topologyPlugin) prePredicateFn(_ *pod_info.PodInfo, job *podgroup_info.
4267 return nil
4368 }
4469
45- _ , err = t .getBestJobAllocatableDomains (job , topologyTree )
70+ jobAllocatableDomain , err : = t .getBestJobAllocatableDomains (job , topologyTree )
4671 if err != nil {
4772 return err
4873 }
4974
75+ //Save results to cycle cache
76+ cycleJobState := (* k8sframework .CycleState )(t .sessionStateGetter .GetSessionStateForResource (job .PodGroupUID ))
77+ cycleJobState .Write (
78+ k8sframework .StateKey (topologyPluginName ),
79+ & topologyStateData {relevantDomains : jobAllocatableDomain },
80+ )
81+
5082 return nil
5183}
5284
@@ -278,3 +310,73 @@ func (*topologyPlugin) treeAllocatableCleanup(topologyTree *TopologyInfo) {
278310 }
279311 }
280312}
313+
314+ func (t * topologyPlugin ) predicateFn (pod * pod_info.PodInfo , job * podgroup_info.PodGroupInfo , node * node_info.NodeInfo ) error {
315+ jobAllocatableDomains , err := t .loadAllocatableDomainsFromCache (job .PodGroupUID )
316+ if err != nil {
317+ return err
318+ }
319+
320+ if len (jobAllocatableDomains ) > 0 {
321+ jobDomainsNames := []string {}
322+ for _ , domain := range jobAllocatableDomains {
323+ if domain .Nodes [node .Node .Name ] != nil {
324+ return nil
325+ }
326+ jobDomainsNames = append (jobDomainsNames , domain .Name )
327+ }
328+ return fmt .Errorf ("the node %s is not part of the chosen topology domain for the job %s. The chosen domains are %s" ,
329+ node .Node .Name , job .PodGroup .Name , strings .Join (jobDomainsNames , ", " ))
330+ }
331+
332+ return nil
333+ }
334+
335+ func (t * topologyPlugin ) nodeOrderFn (pod * pod_info.PodInfo , node * node_info.NodeInfo ) (float64 , error ) {
336+ score := 0.0
337+
338+ jobAllocatableDomains , err := t .loadAllocatableDomainsFromCache (types .UID (pod .Job ))
339+ if err != nil {
340+ return score , err
341+ }
342+
343+ if len (jobAllocatableDomains ) > 0 {
344+ for _ , domain := range jobAllocatableDomains {
345+ if domain .Nodes [node .Node .Name ] != nil {
346+ score = scores .Topology
347+ break
348+ }
349+ }
350+ }
351+
352+ return score , nil
353+ }
354+
355+ func (t * topologyPlugin ) loadAllocatableDomainsFromCache (podGroupUID types.UID ) ([]* TopologyDomainInfo , error ) {
356+ cycleJobState := (* k8sframework .CycleState )(t .sessionStateGetter .GetSessionStateForResource (podGroupUID ))
357+ if cycleJobState == nil {
358+ return nil , nil
359+ }
360+ jobTopologyStateData , err := cycleJobState .Read (k8sframework .StateKey (topologyPluginName ))
361+ if err != nil {
362+ if errors .Is (err , k8sframework .ErrNotFound ) {
363+ return nil , nil
364+ }
365+ return nil , err
366+ }
367+ jobAllocatableDomains := jobTopologyStateData .(* topologyStateData ).relevantDomains
368+ return jobAllocatableDomains , nil
369+ }
370+
371+ func (t * topologyPlugin ) cleanAllocationAttemptCache (job * podgroup_info.PodGroupInfo ) error {
372+ if job .PodGroup .Spec .TopologyConstraint .Topology == "" {
373+ return nil
374+ }
375+
376+ cycleJobState := (* k8sframework .CycleState )(t .sessionStateGetter .GetSessionStateForResource (job .PodGroupUID ))
377+ if cycleJobState == nil {
378+ return nil
379+ }
380+ cycleJobState .Delete (k8sframework .StateKey (topologyPluginName ))
381+ return nil
382+ }
0 commit comments