Skip to content

Commit c1df2a9

Browse files
authored
refactor out reclaimerinfo from API (#172)
1 parent ca093f7 commit c1df2a9

File tree

5 files changed

+43
-42
lines changed

5 files changed

+43
-42
lines changed

pkg/scheduler/actions/reclaim/reclaim.go

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/common_info"
1313
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
1414
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
15-
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/reclaimer_info"
1615
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/framework"
1716
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/log"
1817
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/metrics"
@@ -49,8 +48,7 @@ func (ra *reclaimAction) Execute(ssn *framework.Session) {
4948
for !jobsOrderByQueues.IsEmpty() {
5049
job := jobsOrderByQueues.PopNextJob()
5150

52-
reclaimerInfo := buildReclaimerInfo(ssn, job)
53-
if !ssn.CanReclaimResources(reclaimerInfo) {
51+
if !ssn.CanReclaimResources(job) {
5452
continue
5553
}
5654

@@ -69,7 +67,7 @@ func (ra *reclaimAction) Execute(ssn *framework.Session) {
6967
}
7068
}
7169
metrics.IncPodgroupsConsideredByAction()
72-
succeeded, statement, reclaimeeTasksNames := ra.attemptToReclaimForSpecificJob(ssn, job, reclaimerInfo)
70+
succeeded, statement, reclaimeeTasksNames := ra.attemptToReclaimForSpecificJob(ssn, job)
7371
if succeeded {
7472
metrics.IncPodgroupScheduledByAction()
7573
log.InfraLogger.V(3).Infof(
@@ -88,7 +86,7 @@ func (ra *reclaimAction) Execute(ssn *framework.Session) {
8886
}
8987

9088
func (ra *reclaimAction) attemptToReclaimForSpecificJob(
91-
ssn *framework.Session, reclaimer *podgroup_info.PodGroupInfo, reclaimerInfo *reclaimer_info.ReclaimerInfo,
89+
ssn *framework.Session, reclaimer *podgroup_info.PodGroupInfo,
9290
) (bool, *framework.Statement, []string) {
9391
queue := ssn.Queues[reclaimer.Queue]
9492
resReq := podgroup_info.GetTasksToAllocateInitResource(reclaimer, ssn.TaskOrderFn, false)
@@ -100,30 +98,19 @@ func (ra *reclaimAction) attemptToReclaimForSpecificJob(
10098
feasibleNodes := common.FeasibleNodesForJob(maps.Values(ssn.Nodes), reclaimer)
10199
solver := solvers.NewJobsSolver(
102100
feasibleNodes,
103-
reclaimableScenarioCheck(ssn, reclaimerInfo),
101+
reclaimableScenarioCheck(ssn, reclaimer),
104102
getOrderedVictimsQueue(ssn, reclaimer),
105103
framework.Reclaim)
106104
return solver.Solve(ssn, reclaimer)
107105
}
108106

109107
func reclaimableScenarioCheck(ssn *framework.Session,
110-
reclaimerInfo *reclaimer_info.ReclaimerInfo) solvers.SolutionValidator {
108+
reclaimer *podgroup_info.PodGroupInfo) solvers.SolutionValidator {
111109
return func(
112110
_ *podgroup_info.PodGroupInfo,
113111
victimJobs []*podgroup_info.PodGroupInfo,
114112
victimTasks []*pod_info.PodInfo) bool {
115-
return ssn.ReclaimScenarioValidator(reclaimerInfo, victimJobs, victimTasks)
116-
}
117-
}
118-
119-
func buildReclaimerInfo(ssn *framework.Session, reclaimerJob *podgroup_info.PodGroupInfo) *reclaimer_info.ReclaimerInfo {
120-
return &reclaimer_info.ReclaimerInfo{
121-
Name: reclaimerJob.Name,
122-
Namespace: reclaimerJob.Namespace,
123-
Queue: reclaimerJob.Queue,
124-
IsPreemptable: reclaimerJob.IsPreemptibleJob(ssn.IsInferencePreemptible()),
125-
RequiredResources: podgroup_info.GetTasksToAllocateInitResource(
126-
reclaimerJob, ssn.TaskOrderFn, false),
113+
return ssn.ReclaimScenarioValidator(reclaimer, victimJobs, victimTasks)
127114
}
128115
}
129116

pkg/scheduler/api/types.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
1010
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
1111
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/queue_info"
12-
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/reclaimer_info"
1312
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/resource_info"
1413
)
1514

@@ -20,16 +19,13 @@ type PredicateFn func(*pod_info.PodInfo, *podgroup_info.PodGroupInfo, *node_info
2019
type PrePredicateFn func(*pod_info.PodInfo, *podgroup_info.PodGroupInfo) error
2120

2221
// CanReclaimResourcesFn is a function that determines if a reclaimer can get more resources
23-
type CanReclaimResourcesFn func(reclaimerInfo *reclaimer_info.ReclaimerInfo) bool
22+
type CanReclaimResourcesFn func(reclaimer *podgroup_info.PodGroupInfo) bool
2423

25-
// VictimFilterFn is a function which filters out jobs that cannot a victim candidate for a specific reclaimer.
24+
// VictimFilterFn is a function which filters out jobs that cannot be a victim candidate for a specific reclaimer/preemptor.
2625
type VictimFilterFn func(pendingJob *podgroup_info.PodGroupInfo, victim *podgroup_info.PodGroupInfo) bool
2726

28-
// ScenarioValidatorFn is a function which determines the validity of a reclaim scenario.
29-
type ScenarioValidatorFn func(reclaimerInfo *reclaimer_info.ReclaimerInfo, victims []*podgroup_info.PodGroupInfo, tasks []*pod_info.PodInfo) bool
30-
31-
// PreemptScenarioValidatorFn is a function which determines the validity of a preempt scenario.
32-
type PreemptScenarioValidatorFn func(preemptor *podgroup_info.PodGroupInfo, victims []*podgroup_info.PodGroupInfo, tasks []*pod_info.PodInfo) bool
27+
// ScenarioValidatorFn is a function which determines the validity of a scenario.
28+
type ScenarioValidatorFn func(pendingJob *podgroup_info.PodGroupInfo, victims []*podgroup_info.PodGroupInfo, tasks []*pod_info.PodInfo) bool
3329

3430
// QueueResource is a function which returns the resource of a queue.
3531
type QueueResource func(*queue_info.QueueInfo) *resource_info.ResourceRequirements

pkg/scheduler/framework/session.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ type Session struct {
5050
CanReclaimResourcesFns []api.CanReclaimResourcesFn
5151
ReclaimVictimFilterFns []api.VictimFilterFn
5252
PreemptVictimFilterFns []api.VictimFilterFn
53-
ReclaimScenarioValidators []api.ScenarioValidatorFn
54-
PreemptScenarioValidators []api.PreemptScenarioValidatorFn
53+
ReclaimScenarioValidatorFns []api.ScenarioValidatorFn
54+
PreemptScenarioValidatorFns []api.ScenarioValidatorFn
5555
OnJobSolutionStartFns []api.OnJobSolutionStartFn
5656
GetQueueAllocatedResourcesFns []api.QueueResource
5757
GetQueueDeservedResourcesFns []api.QueueResource

pkg/scheduler/framework/session_plugins.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
1313
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
1414
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/queue_info"
15-
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/reclaimer_info"
1615
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/resource_info"
1716
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/log"
1817
)
@@ -68,14 +67,18 @@ func (ssn *Session) AddCanReclaimResourcesFn(crf api.CanReclaimResourcesFn) {
6867
}
6968

7069
func (ssn *Session) AddReclaimScenarioValidatorFn(rf api.ScenarioValidatorFn) {
71-
ssn.ReclaimScenarioValidators = append(ssn.ReclaimScenarioValidators, rf)
70+
ssn.ReclaimScenarioValidatorFns = append(ssn.ReclaimScenarioValidatorFns, rf)
71+
}
72+
73+
func (ssn *Session) AddPreemptScenarioValidatorFn(rf api.ScenarioValidatorFn) {
74+
ssn.PreemptScenarioValidatorFns = append(ssn.PreemptScenarioValidatorFns, rf)
7275
}
7376

7477
func (ssn *Session) AddReclaimeeFilterFn(rf api.VictimFilterFn) {
7578
ssn.ReclaimVictimFilterFns = append(ssn.ReclaimVictimFilterFns, rf)
7679
}
7780

78-
func (ssn *Session) CanReclaimResources(reclaimer *reclaimer_info.ReclaimerInfo) bool {
81+
func (ssn *Session) CanReclaimResources(reclaimer *podgroup_info.PodGroupInfo) bool {
7982
for _, canReclaimFn := range ssn.CanReclaimResourcesFns {
8083
return canReclaimFn(reclaimer)
8184
}
@@ -94,11 +97,11 @@ func (ssn *Session) ReclaimVictimFilter(reclaimer *podgroup_info.PodGroupInfo, v
9497
}
9598

9699
func (ssn *Session) ReclaimScenarioValidator(
97-
reclaimer *reclaimer_info.ReclaimerInfo,
100+
reclaimer *podgroup_info.PodGroupInfo,
98101
reclaimees []*podgroup_info.PodGroupInfo,
99102
victimsTasks []*pod_info.PodInfo,
100103
) bool {
101-
for _, rf := range ssn.ReclaimScenarioValidators {
104+
for _, rf := range ssn.ReclaimScenarioValidatorFns {
102105
if !rf(reclaimer, reclaimees, victimsTasks) {
103106
return false
104107
}
@@ -122,7 +125,7 @@ func (ssn *Session) PreemptScenarioValidator(
122125
victimJobs []*podgroup_info.PodGroupInfo,
123126
victimTasks []*pod_info.PodInfo,
124127
) bool {
125-
for _, pf := range ssn.PreemptScenarioValidators {
128+
for _, pf := range ssn.PreemptScenarioValidatorFns {
126129
if !pf(preemptor, victimJobs, victimTasks) {
127130
return false
128131
}

pkg/scheduler/plugins/proportion/proportion.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ type proportionPlugin struct {
3737
queues map[common_info.QueueID]*rs.QueueAttributes
3838
jobSimulationQueues map[common_info.QueueID]*rs.QueueAttributes
3939
// Arguments given for the plugin
40-
pluginArguments map[string]string
41-
taskOrderFunc common_info.LessFn
42-
reclaimablePlugin *rec.Reclaimable
40+
pluginArguments map[string]string
41+
taskOrderFunc common_info.LessFn
42+
reclaimablePlugin *rec.Reclaimable
43+
isInferencePreemptible bool
4344
}
4445

4546
func New(arguments map[string]string) framework.Plugin {
@@ -59,6 +60,7 @@ func (pp *proportionPlugin) OnSessionOpen(ssn *framework.Session) {
5960
pp.taskOrderFunc = ssn.TaskOrderFn
6061
pp.reclaimablePlugin = rec.New(ssn.IsInferencePreemptible(), pp.taskOrderFunc)
6162

63+
pp.isInferencePreemptible = ssn.IsInferencePreemptible()
6264
capacityPolicy := cp.New(pp.queues, ssn.IsInferencePreemptible())
6365
ssn.AddQueueOrderFn(pp.queueOrder)
6466
ssn.AddCanReclaimResourcesFn(pp.CanReclaimResourcesFn)
@@ -91,15 +93,17 @@ func (pp *proportionPlugin) OnJobSolutionStartFn() {
9193
}
9294
}
9395

94-
func (pp *proportionPlugin) CanReclaimResourcesFn(reclaimer *reclaimer_info.ReclaimerInfo) bool {
95-
return pp.reclaimablePlugin.CanReclaimResources(pp.queues, reclaimer)
96+
func (pp *proportionPlugin) CanReclaimResourcesFn(reclaimer *podgroup_info.PodGroupInfo) bool {
97+
reclaimerInfo := pp.buildReclaimerInfo(reclaimer)
98+
return pp.reclaimablePlugin.CanReclaimResources(pp.queues, reclaimerInfo)
9699
}
97100

98101
func (pp *proportionPlugin) reclaimableFn(
99-
reclaimer *reclaimer_info.ReclaimerInfo,
102+
reclaimer *podgroup_info.PodGroupInfo,
100103
reclaimees []*podgroup_info.PodGroupInfo,
101104
_ []*pod_info.PodInfo,
102105
) bool {
106+
reclaimerInfo := pp.buildReclaimerInfo(reclaimer)
103107
totalVictimsResources := make(map[common_info.QueueID][]*resource_info.Resource)
104108
for _, jobTaskGroup := range reclaimees {
105109
totalJobResources := resource_info.EmptyResource()
@@ -113,7 +117,7 @@ func (pp *proportionPlugin) reclaimableFn(
113117
)
114118
}
115119

116-
return pp.reclaimablePlugin.Reclaimable(pp.jobSimulationQueues, reclaimer, totalVictimsResources)
120+
return pp.reclaimablePlugin.Reclaimable(pp.jobSimulationQueues, reclaimerInfo, totalVictimsResources)
117121
}
118122

119123
func (pp *proportionPlugin) calculateResourcesProportion(ssn *framework.Session) {
@@ -180,6 +184,17 @@ func (pp *proportionPlugin) createQueueAttributes(ssn *framework.Session) {
180184
pp.setFairShare()
181185
}
182186

187+
func (pp *proportionPlugin) buildReclaimerInfo(reclaimer *podgroup_info.PodGroupInfo) *reclaimer_info.ReclaimerInfo {
188+
return &reclaimer_info.ReclaimerInfo{
189+
Name: reclaimer.Name,
190+
Namespace: reclaimer.Namespace,
191+
Queue: reclaimer.Queue,
192+
IsPreemptable: reclaimer.IsPreemptibleJob(pp.isInferencePreemptible),
193+
RequiredResources: podgroup_info.GetTasksToAllocateInitResource(
194+
reclaimer, pp.taskOrderFunc, false),
195+
}
196+
}
197+
183198
func (pp *proportionPlugin) createQueueResourceAttrs(ssn *framework.Session) {
184199
for _, queue := range ssn.Queues {
185200
queueAttributes := &rs.QueueAttributes{

0 commit comments

Comments
 (0)