44package topology
55
66import (
7+ "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/common_info"
78 "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
8- "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
99 "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/framework"
10+ "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/k8s_internal"
1011 kueuev1alpha1 "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
1112)
1213
@@ -16,8 +17,11 @@ const (
1617)
1718
// topologyPlugin holds the plugin's state. This span is a diff hunk: lines marked "-" are the
// previous two-field layout, lines marked "+" are the replacement five-field layout.
1819type topologyPlugin struct {
19- enabled bool
20- TopologyTrees map [string ]* TopologyInfo
20+ enabled bool
// taskOrderFunc is assigned from ssn.TaskOrderFn in OnSessionOpen.
21+ taskOrderFunc common_info.LessFn
// sessionStateGetter is assigned the session itself in OnSessionOpen.
22+ sessionStateGetter k8s_internal.SessionStateProvider
// nodesInfos is assigned from ssn.Nodes in OnSessionOpen.
23+ nodesInfos map [string ]* node_info.NodeInfo
// TopologyTrees maps a string key to a topology tree; presumably the key is the topology name — TODO confirm.
24+ TopologyTrees map [string ]* TopologyInfo
2125}
2226
2327func New (pluginArgs map [string ]string ) framework.Plugin {
@@ -33,19 +37,17 @@ func (t *topologyPlugin) Name() string {
3337
// OnSessionOpen stores session data on the plugin and builds the topology trees from ssn.Topologies.
// In this diff the "+" lines add the three assignments below, and the "-" lines drop the
// registration of the Allocate/Deallocate event handlers.
3438func (t * topologyPlugin ) OnSessionOpen (ssn * framework.Session ) {
3539 topologies := ssn .Topologies
// Keep the session's task ordering function, the session itself, and its node map on the plugin.
40+ t .taskOrderFunc = ssn .TaskOrderFn
41+ t .sessionStateGetter = ssn
42+ t .nodesInfos = ssn .Nodes
3643 t .initializeTopologyTree (topologies , ssn )
37-
// Removed by this diff: event handlers that called handleAllocate / handleDeallocate.
38- ssn .AddEventHandler (& framework.EventHandler {
39- AllocateFunc : t .handleAllocate (ssn ),
40- DeallocateFunc : t .handleDeallocate (ssn ),
41- })
4244}
4345
4446func (t * topologyPlugin ) initializeTopologyTree (topologies []* kueuev1alpha1.Topology , ssn * framework.Session ) {
4547 for _ , singleTopology := range topologies {
4648 topologyTree := & TopologyInfo {
4749 Name : singleTopology .Name ,
48- Domains : map [TopologyDomainID ]* TopologyDomainInfo {},
50+ DomainsByLevel : map [ string ] map [TopologyDomainID ]* TopologyDomainInfo {},
4951 Root : NewTopologyDomainInfo (TopologyDomainID ("root" ), "datacenter" , "cluster" , 0 ),
5052 TopologyResource : singleTopology ,
5153 }
@@ -69,10 +71,16 @@ func (*topologyPlugin) addNodeDataToTopology(topologyTree *TopologyInfo, singleT
6971 }
7072
7173 domainId := calcDomainId (levelIndex , singleTopology .Spec .Levels , nodeInfo .Node .Labels )
72- domainInfo , foundLevelLabel := topologyTree .Domains [domainId ]
74+ domainLevel := level .NodeLabel
75+ domainsForLevel , foundLevelLabel := topologyTree .DomainsByLevel [domainLevel ]
7376 if ! foundLevelLabel {
77+ topologyTree .DomainsByLevel [level .NodeLabel ] = map [TopologyDomainID ]* TopologyDomainInfo {}
78+ domainsForLevel = topologyTree .DomainsByLevel [level .NodeLabel ]
79+ }
80+ domainInfo , foundDomain := domainsForLevel [domainId ]
81+ if ! foundDomain {
7482 domainInfo = NewTopologyDomainInfo (domainId , domainName , level .NodeLabel , levelIndex + 1 )
75- topologyTree . Domains [domainId ] = domainInfo
83+ domainsForLevel [domainId ] = domainInfo
7684 }
7785 domainInfo .AddNode (nodeInfo )
7886
@@ -86,48 +94,4 @@ func (*topologyPlugin) addNodeDataToTopology(topologyTree *TopologyInfo, singleT
8694 topologyTree .Root .AddNode (nodeInfo )
8795}
8896
// handleAllocate returns an event callback built by updateTopologyGivenPodEvent.
// For each domain the callback is applied to, it adds the pod's AcceptedResource to the domain's
// AllocatedResources and increases the "pods" scalar entry by one.
// Every line of this function carries a "-" marker: the diff removes it.
89- func (t * topologyPlugin ) handleAllocate (ssn * framework.Session ) func (event * framework.Event ) {
90- return t .updateTopologyGivenPodEvent (ssn , func (domainInfo * TopologyDomainInfo , podInfo * pod_info.PodInfo ) {
91- domainInfo .AllocatedResources .AddResourceRequirements (podInfo .AcceptedResource )
92- domainInfo .AllocatedResources .BaseResource .ScalarResources ()["pods" ] =
93- domainInfo .AllocatedResources .BaseResource .ScalarResources ()["pods" ] + 1
94- })
95- }
96-
// handleDeallocate returns an event callback built by updateTopologyGivenPodEvent.
// For each domain the callback is applied to, it subtracts the pod's AcceptedResource from the
// domain's AllocatedResources and decreases the "pods" scalar entry by one.
// Every line of this function carries a "-" marker: the diff removes it.
97- func (t * topologyPlugin ) handleDeallocate (ssn * framework.Session ) func (event * framework.Event ) {
98- return t .updateTopologyGivenPodEvent (ssn , func (domainInfo * TopologyDomainInfo , podInfo * pod_info.PodInfo ) {
99- domainInfo .AllocatedResources .SubResourceRequirements (podInfo .AcceptedResource )
100- domainInfo .AllocatedResources .BaseResource .ScalarResources ()["pods" ] =
101- domainInfo .AllocatedResources .BaseResource .ScalarResources ()["pods" ] - 1
102- })
103- }
104-
// updateTopologyGivenPodEvent wraps domainUpdater in an event callback.
// The callback ignores tasks whose NodeName equals noNodeName. Otherwise it reads the node and the
// pod info from ssn.Nodes and, for every tree in t.TopologyTrees, starts at the domain keyed by the
// node's leaf domain id and applies domainUpdater while following Parent links.
// Every line of this function carries a "-" marker: the diff removes it.
105- func (t * topologyPlugin ) updateTopologyGivenPodEvent (
106- ssn * framework.Session ,
107- domainUpdater func (domainInfo * TopologyDomainInfo , podInfo * pod_info.PodInfo ),
108- ) func (event * framework.Event ) {
109- return func (event * framework.Event ) {
110- pod := event .Task .Pod
111- nodeName := event .Task .NodeName
// Tasks without an assigned node cannot be mapped to a domain, so nothing is updated.
112- if nodeName == noNodeName {
113- return
114- }
// NOTE(review): ssn.Nodes[nodeName] is used without a presence check; if the map holds pointers
// (as the nodesInfos field type suggests), an unknown node name would dereference nil — confirm.
115- node := ssn .Nodes [nodeName ].Node
116- podInfo := ssn .Nodes [nodeName ].PodInfos [pod_info .PodKey (pod )]
117-
118- for _ , topologyTree := range t .TopologyTrees {
119- leafDomainId := calcLeafDomainId (topologyTree .TopologyResource , node .Labels )
// A leaf id missing from Domains yields a nil entry, so the loop below is skipped for this tree.
120- domainInfo := topologyTree .Domains [leafDomainId ]
121- for domainInfo != nil {
122- domainUpdater (domainInfo , podInfo )
123-
// Stop climbing once the domain that was just updated has an entry for this node in its Nodes map.
124- if domainInfo .Nodes [nodeName ] != nil {
125- break
126- }
127- domainInfo = domainInfo .Parent
128- }
129- }
130- }
131- }
132-
// OnSessionClose is a no-op: the plugin performs no work when the session closes.
13397func (t * topologyPlugin ) OnSessionClose (ssn * framework.Session ) {}
0 commit comments