Skip to content

Commit 40c65b0

Browse files
committed
Added statuses cache to SubGroupInfo
1 parent d5c1aa0 commit 40c65b0

File tree

2 files changed

+117
-48
lines changed

2 files changed

+117
-48
lines changed

pkg/scheduler/api/podgroup_info/subgroup_info.go

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,32 @@ package podgroup_info
55

66
import (
77
"github.com/NVIDIA/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
8+
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/common_info"
89
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
910
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_status"
1011
)
1112

1213
type SubGroupInfo struct {
13-
name string
14-
minAvailable int32
15-
podInfos pod_info.PodsMap
14+
name string
15+
minAvailable int32
16+
podInfos pod_info.PodsMap
17+
podStatusIndex map[pod_status.PodStatus]pod_info.PodsMap
18+
podStatusMap map[common_info.PodID]pod_status.PodStatus
19+
numActiveAllocatedTasks int
20+
numActiveUsedTasks int
21+
numAliveTasks int
1622
}
1723

1824
func NewSubGroupInfo(name string, minAvailable int32) *SubGroupInfo {
1925
return &SubGroupInfo{
20-
name: name,
21-
minAvailable: minAvailable,
22-
podInfos: pod_info.PodsMap{},
26+
name: name,
27+
minAvailable: minAvailable,
28+
podInfos: pod_info.PodsMap{},
29+
podStatusIndex: map[pod_status.PodStatus]pod_info.PodsMap{},
30+
podStatusMap: map[common_info.PodID]pod_status.PodStatus{},
31+
numActiveAllocatedTasks: 0,
32+
numActiveUsedTasks: 0,
33+
numAliveTasks: 0,
2334
}
2435
}
2536

@@ -44,11 +55,55 @@ func (sgi *SubGroupInfo) GetPodInfos() pod_info.PodsMap {
4455
}
4556

4657
func (sgi *SubGroupInfo) AssignTask(ti *pod_info.PodInfo) {
58+
sgi.clearOldStatus(ti)
59+
60+
if _, found := sgi.podStatusIndex[ti.Status]; !found {
61+
sgi.podStatusIndex[ti.Status] = pod_info.PodsMap{}
62+
}
63+
sgi.podStatusIndex[ti.Status][ti.UID] = ti
64+
65+
if pod_status.IsActiveAllocatedStatus(ti.Status) {
66+
sgi.numActiveAllocatedTasks += 1
67+
}
68+
if pod_status.IsActiveUsedStatus(ti.Status) {
69+
sgi.numActiveUsedTasks += 1
70+
}
71+
if pod_status.IsAliveStatus(ti.Status) {
72+
sgi.numAliveTasks += 1
73+
}
74+
75+
sgi.podStatusMap[ti.UID] = ti.Status
4776
sgi.podInfos[ti.UID] = ti
4877
}
4978

50-
func (sgi *SubGroupInfo) WithPodInfos(podInfos pod_info.PodsMap) *SubGroupInfo {
51-
sgi.podInfos = podInfos
79+
func (sgi *SubGroupInfo) clearOldStatus(ti *pod_info.PodInfo) {
80+
oldStatus, found := sgi.podStatusMap[ti.UID]
81+
if !found {
82+
return
83+
}
84+
if pod_status.IsActiveAllocatedStatus(oldStatus) {
85+
sgi.numActiveAllocatedTasks -= 1
86+
}
87+
if pod_status.IsActiveUsedStatus(oldStatus) {
88+
sgi.numActiveUsedTasks -= 1
89+
}
90+
if pod_status.IsAliveStatus(oldStatus) {
91+
sgi.numAliveTasks -= 1
92+
}
93+
94+
delete(sgi.podStatusIndex[oldStatus], ti.UID)
95+
delete(sgi.podStatusMap, ti.UID)
96+
delete(sgi.podInfos, ti.UID)
97+
}
98+
99+
func (sgi *SubGroupInfo) WithPodInfos(tasks pod_info.PodsMap) *SubGroupInfo {
100+
for _, oldTask := range sgi.podInfos {
101+
sgi.clearOldStatus(oldTask)
102+
}
103+
104+
for _, task := range tasks {
105+
sgi.AssignTask(task)
106+
}
52107
return sgi
53108
}
54109

@@ -63,51 +118,21 @@ func (sgi *SubGroupInfo) IsGangSatisfied() bool {
63118
}
64119

65120
func (sgi *SubGroupInfo) GetNumActiveAllocatedTasks() int {
66-
taskCount := 0
67-
for _, task := range sgi.podInfos {
68-
if pod_status.IsActiveAllocatedStatus(task.Status) {
69-
taskCount++
70-
}
71-
}
72-
return taskCount
121+
return sgi.numActiveAllocatedTasks
73122
}
74123

75124
func (sgi *SubGroupInfo) GetNumActiveUsedTasks() int {
76-
numTasks := 0
77-
for _, podInfo := range sgi.podInfos {
78-
if pod_status.IsActiveUsedStatus(podInfo.Status) {
79-
numTasks += 1
80-
}
81-
}
82-
return numTasks
125+
return sgi.numActiveUsedTasks
83126
}
84127

85128
func (sgi *SubGroupInfo) GetNumAliveTasks() int {
86-
numTasks := 0
87-
for _, task := range sgi.podInfos {
88-
if pod_status.IsAliveStatus(task.Status) {
89-
numTasks += 1
90-
}
91-
}
92-
return numTasks
129+
return sgi.numAliveTasks
93130
}
94131

95132
func (sgi *SubGroupInfo) GetNumGatedTasks() int {
96-
numTasks := 0
97-
for _, podInfo := range sgi.podInfos {
98-
if podInfo.Status == pod_status.Gated {
99-
numTasks += 1
100-
}
101-
}
102-
return numTasks
133+
return len(sgi.podStatusIndex[pod_status.Gated])
103134
}
104135

105136
func (sgi *SubGroupInfo) GetNumPendingTasks() int {
106-
numTasks := 0
107-
for _, podInfo := range sgi.podInfos {
108-
if podInfo.Status == pod_status.Pending {
109-
numTasks += 1
110-
}
111-
}
112-
return numTasks
137+
return len(sgi.podStatusIndex[pod_status.Pending])
113138
}

pkg/scheduler/api/podgroup_info/subgroup_info_test.go

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,49 @@ func TestFromSubGroup(t *testing.T) {
4545
}
4646
}
4747

48+
func TestWithPodInfos(t *testing.T) {
49+
sgi := NewSubGroupInfo("test", 1)
50+
51+
// Pre-populate with a pod that should be cleared
52+
sgi.AssignTask(&pod_info.PodInfo{UID: "old", Status: pod_status.Running})
53+
54+
// Prepare new pod infos
55+
p1 := &pod_info.PodInfo{UID: "pod1", Status: pod_status.Pending}
56+
p2 := &pod_info.PodInfo{UID: "pod2", Status: pod_status.Running}
57+
replacement := pod_info.PodsMap{
58+
"pod1": p1,
59+
"pod2": p2,
60+
}
61+
62+
sgi.WithPodInfos(replacement)
63+
64+
gotInfos := sgi.GetPodInfos()
65+
if len(gotInfos) != 2 {
66+
t.Errorf("expected len(podInfos)==2, got %d", len(gotInfos))
67+
}
68+
if gotInfos["pod1"] != p1 {
69+
t.Errorf("pod1 entry is not correct")
70+
}
71+
if gotInfos["pod2"] != p2 {
72+
t.Errorf("pod2 entry is not correct")
73+
}
74+
// Old pod should not be present
75+
if _, ok := gotInfos["old"]; ok {
76+
t.Errorf("expected old pod to be cleared by WithPodInfos")
77+
}
78+
79+
// Check counters based on status
80+
if want := 2; sgi.GetNumAliveTasks() != want {
81+
t.Errorf("GetNumAliveTasks() = %d, want %d", sgi.GetNumAliveTasks(), want)
82+
}
83+
if want := 1; sgi.GetNumActiveAllocatedTasks() != want {
84+
t.Errorf("GetNumActiveAllocatedTasks() = %d, want %d", sgi.GetNumActiveAllocatedTasks(), want)
85+
}
86+
if want := 1; sgi.GetNumPendingTasks() != want {
87+
t.Errorf("GetNumPendingTasks() = %d, want %d", sgi.GetNumPendingTasks(), want)
88+
}
89+
}
90+
4891
func TestGetName(t *testing.T) {
4992
name := "test-subgroup"
5093
sgi := NewSubGroupInfo(name, 1)
@@ -235,17 +278,18 @@ func TestGetNumActiveUsedTasks(t *testing.T) {
235278
pods := []*pod_info.PodInfo{
236279
{UID: "1", Status: pod_status.Running},
237280
{UID: "2", Status: pod_status.Pipelined},
238-
{UID: "3", Status: pod_status.Failed},
239-
{UID: "4", Status: pod_status.Succeeded},
281+
{UID: "3", Status: pod_status.Releasing},
282+
{UID: "4", Status: pod_status.Failed},
283+
{UID: "5", Status: pod_status.Succeeded},
240284
}
241285

242286
for _, pod := range pods {
243287
sgi.AssignTask(pod)
244288
}
245289

246-
expected := 2 // Running and Pipelined are active allocated statuses
247-
if got := sgi.GetNumActiveAllocatedTasks(); got != expected {
248-
t.Errorf("GetNumAliveTasks() = %v, want %v", got, expected)
290+
expected := 3 // Running, Pipelined, and Releasing are considered active used statuses
291+
if got := sgi.GetNumActiveUsedTasks(); got != expected {
292+
t.Errorf("GetNumActiveUsedTasks() = %v, want %v", got, expected)
249293
}
250294
}
251295

0 commit comments

Comments (0)