Skip to content

Commit 4f003fc

Browse files
authored
refactor: Fix serialization of conf object (#633)
* Fix serialization of conf object
1 parent 9cade65 commit 4f003fc

File tree

14 files changed

+171
-128
lines changed

14 files changed

+171
-128
lines changed

pkg/env-tests/timeaware/timeaware.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ func setupControllers(backgroundCtx context.Context, cfg *rest.Config,
120120
ClientType: "fake-with-history",
121121
ConnectionString: "fake-connection",
122122
UsageParams: &api.UsageParams{
123-
WindowSize: &[]time.Duration{time.Second * time.Duration(*windowSize)}[0],
124-
FetchInterval: &[]time.Duration{time.Millisecond}[0],
125-
HalfLifePeriod: &[]time.Duration{time.Second * time.Duration(*halfLifePeriod)}[0],
123+
WindowSize: &metav1.Duration{Duration: time.Second * time.Duration(*windowSize)},
124+
FetchInterval: &metav1.Duration{Duration: time.Millisecond},
125+
HalfLifePeriod: &metav1.Duration{Duration: time.Second * time.Duration(*halfLifePeriod)},
126126
},
127127
}
128128
schedulerConf.UsageDBConfig.UsageParams.SetDefaults()

pkg/operator/operands/scheduler/resources_for_shard.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strconv"
1010
"strings"
1111

12+
"github.com/spf13/pflag"
1213
"golang.org/x/exp/slices"
1314

1415
"gopkg.in/yaml.v3"
@@ -22,7 +23,7 @@ import (
2223
kaiv1 "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1"
2324
kaiConfigUtils "github.com/NVIDIA/KAI-scheduler/pkg/operator/config"
2425
"github.com/NVIDIA/KAI-scheduler/pkg/operator/operands/common"
25-
"github.com/spf13/pflag"
26+
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/conf"
2627
)
2728

2829
const (
@@ -125,7 +126,7 @@ func (s *SchedulerForShard) configMapForShard(
125126
APIVersion: "v1",
126127
}
127128
placementArguments := calculatePlacementArguments(shard.Spec.PlacementStrategy)
128-
innerConfig := config{}
129+
innerConfig := conf.SchedulerConfiguration{}
129130

130131
actions := []string{"allocate"}
131132
if placementArguments[gpuResource] != spreadStrategy && placementArguments[cpuResource] != spreadStrategy {
@@ -135,9 +136,9 @@ func (s *SchedulerForShard) configMapForShard(
135136

136137
innerConfig.Actions = strings.Join(actions, ", ")
137138

138-
innerConfig.Tiers = []tier{
139+
innerConfig.Tiers = []conf.Tier{
139140
{
140-
Plugins: []plugin{
141+
Plugins: []conf.PluginOption{
141142
{Name: "predicates"},
142143
{Name: "proportion"},
143144
{Name: "priority"},
@@ -160,8 +161,8 @@ func (s *SchedulerForShard) configMapForShard(
160161

161162
innerConfig.Tiers[0].Plugins = append(
162163
innerConfig.Tiers[0].Plugins,
163-
plugin{Name: fmt.Sprintf("gpu%s", strings.Replace(placementArguments[gpuResource], "bin", "", 1))},
164-
plugin{
164+
conf.PluginOption{Name: fmt.Sprintf("gpu%s", strings.Replace(placementArguments[gpuResource], "bin", "", 1))},
165+
conf.PluginOption{
165166
Name: "nodeplacement",
166167
Arguments: placementArguments,
167168
},
@@ -170,7 +171,7 @@ func (s *SchedulerForShard) configMapForShard(
170171
if placementArguments[gpuResource] == binpackStrategy {
171172
innerConfig.Tiers[0].Plugins = append(
172173
innerConfig.Tiers[0].Plugins,
173-
plugin{Name: "gpusharingorder"},
174+
conf.PluginOption{Name: "gpusharingorder"},
174175
)
175176
}
176177

@@ -195,7 +196,7 @@ func (s *SchedulerForShard) configMapForShard(
195196
return schedulerConfig, nil
196197
}
197198

198-
func validateJobDepthMap(shard *kaiv1.SchedulingShard, innerConfig config, actions []string) error {
199+
func validateJobDepthMap(shard *kaiv1.SchedulingShard, innerConfig conf.SchedulerConfiguration, actions []string) error {
199200
for actionToConfigure := range shard.Spec.QueueDepthPerAction {
200201
if !slices.Contains(actions, actionToConfigure) {
201202
return fmt.Errorf(invalidJobDepthMapError, innerConfig.Actions, actionToConfigure)
@@ -294,12 +295,12 @@ func calculatePlacementArguments(placementStrategy *kaiv1.PlacementStrategy) map
294295
}
295296
}
296297

297-
func addMinRuntimePluginIfNeeded(plugins *[]plugin, minRuntime *kaiv1.MinRuntime) {
298+
func addMinRuntimePluginIfNeeded(plugins *[]conf.PluginOption, minRuntime *kaiv1.MinRuntime) {
298299
if minRuntime == nil || (minRuntime.PreemptMinRuntime == nil && minRuntime.ReclaimMinRuntime == nil) {
299300
return
300301
}
301302

302-
minRuntimePlugin := plugin{Name: "minruntime", Arguments: map[string]string{}}
303+
minRuntimePlugin := conf.PluginOption{Name: "minruntime", Arguments: map[string]string{}}
303304

304305
if minRuntime.PreemptMinRuntime != nil {
305306
minRuntimePlugin.Arguments["defaultPreemptMinRuntime"] = *minRuntime.PreemptMinRuntime

pkg/operator/operands/scheduler/resources_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
kaiv1 "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1"
1818
kaiv1qc "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1/queue_controller"
1919
kaiv1scheduler "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1/scheduler"
20+
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/conf"
2021

2122
"github.com/stretchr/testify/assert"
2223
"github.com/stretchr/testify/require"
@@ -187,7 +188,7 @@ func TestValidateJobDepthMap(t *testing.T) {
187188

188189
for _, tt := range tests {
189190
t.Run(tt.name, func(t *testing.T) {
190-
innerConfig := config{
191+
innerConfig := conf.SchedulerConfiguration{
191192
Actions: strings.Join(tt.actions, ", "),
192193
}
193194

@@ -454,15 +455,15 @@ tiers:
454455
require.True(t, found, "ConfigMap missing config.yaml")
455456

456457
// Unmarshal expected YAML from test case
457-
var expectedConfig config
458+
var expectedConfig conf.SchedulerConfiguration
458459
if _, ok := tt.expected["config.yaml"]; !ok {
459460
t.Fatal("Test case must provide expected YAML for config.yaml")
460461
}
461462
err = yaml.Unmarshal([]byte(tt.expected["config.yaml"]), &expectedConfig)
462463
require.NoError(t, err, "Failed to unmarshal expected config")
463464

464465
// Unmarshal actual YAML from ConfigMap
465-
var actualConfig config
466+
var actualConfig conf.SchedulerConfiguration
466467
err = yaml.Unmarshal([]byte(actualYAML), &actualConfig)
467468
require.NoError(t, err, "Failed to unmarshal actual config")
468469

pkg/operator/operands/scheduler/scheduler.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,6 @@ const (
2222
defaultResourceName = "scheduler"
2323
)
2424

25-
type config struct {
26-
Actions string `yaml:"actions"`
27-
Tiers []tier `yaml:"tiers,omitempty"`
28-
QueueDepthPerAction map[string]int `yaml:"queueDepthPerAction,omitempty"`
29-
}
30-
31-
type tier struct {
32-
Plugins []plugin `yaml:"plugins"`
33-
}
34-
35-
type plugin struct {
36-
Name string `yaml:"name"`
37-
Arguments map[string]string `yaml:"arguments,omitempty"`
38-
}
39-
4025
type SchedulerForShard struct {
4126
schedulingShard *kaiv1.SchedulingShard
4227

pkg/operator/operands/scheduler/scheduler_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ var _ = Describe("Scheduler", func() {
137137
cm := cmObj.(*v1.ConfigMap)
138138

139139
Expect(err).To(BeNil())
140-
Expect(cm.Data["config.yaml"]).To(Equal(`actions: allocate, consolidation, reclaim, preempt, stalegangeviction
140+
Expect(cm.Data["config.yaml"]).To(MatchYAML(`actions: allocate, consolidation, reclaim, preempt, stalegangeviction
141141
tiers:
142142
- plugins:
143143
- name: predicates
@@ -176,7 +176,7 @@ tiers:
176176
cm := cmObj.(*v1.ConfigMap)
177177

178178
Expect(err).To(BeNil())
179-
Expect(cm.Data["config.yaml"]).To(Equal(`actions: allocate, reclaim, preempt, stalegangeviction
179+
Expect(cm.Data["config.yaml"]).To(MatchYAML(`actions: allocate, reclaim, preempt, stalegangeviction
180180
tiers:
181181
- plugins:
182182
- name: predicates

pkg/scheduler/cache/cache.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ func newSchedulerCache(schedulerCacheParams *SchedulerCacheParams) *SchedulerCac
169169

170170
if schedulerCacheParams.UsageDBClient != nil {
171171
sc.usageLister = usagedb.NewUsageLister(schedulerCacheParams.UsageDBClient,
172-
schedulerCacheParams.UsageDBParams.FetchInterval,
173-
schedulerCacheParams.UsageDBParams.StalenessPeriod,
174-
schedulerCacheParams.UsageDBParams.WaitTimeout)
172+
&schedulerCacheParams.UsageDBParams.FetchInterval.Duration,
173+
&schedulerCacheParams.UsageDBParams.StalenessPeriod.Duration,
174+
&schedulerCacheParams.UsageDBParams.WaitTimeout.Duration)
175175
}
176176

177177
clusterInfo, err := cluster_info.New(sc.informerFactory, sc.kubeAiSchedulerInformerFactory, sc.kueueInformerFactory, sc.usageLister, sc.schedulingNodePoolParams,

pkg/scheduler/cache/usagedb/api/defaults.go

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,31 @@
33

44
package api
55

6-
import "time"
6+
import (
7+
"time"
78

8-
func (up *UsageParams) SetDefaults() {
9-
if up.HalfLifePeriod == nil {
9+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
10+
)
11+
12+
func (p *UsageParams) SetDefaults() {
13+
if p.HalfLifePeriod == nil {
1014
// noop: disabled by default
1115
}
12-
if up.WindowSize == nil {
13-
windowSize := time.Hour * 24 * 7
14-
up.WindowSize = &windowSize
16+
if p.WindowSize == nil {
17+
p.WindowSize = &metav1.Duration{Duration: time.Hour * 24 * 7}
1518
}
16-
if up.WindowType == nil {
19+
if p.WindowType == nil {
1720
windowType := SlidingWindow
18-
up.WindowType = &windowType
21+
p.WindowType = &windowType
1922
}
20-
if up.FetchInterval == nil {
21-
fetchInterval := 1 * time.Minute
22-
up.FetchInterval = &fetchInterval
23+
if p.FetchInterval == nil {
24+
p.FetchInterval = &metav1.Duration{Duration: 1 * time.Minute}
2325
}
24-
if up.StalenessPeriod == nil {
25-
stalenessPeriod := 5 * time.Minute
26-
up.StalenessPeriod = &stalenessPeriod
26+
if p.StalenessPeriod == nil {
27+
p.StalenessPeriod = &metav1.Duration{Duration: 5 * time.Minute}
2728
}
28-
if up.WaitTimeout == nil {
29-
waitTimeout := 1 * time.Minute
30-
up.WaitTimeout = &waitTimeout
29+
if p.WaitTimeout == nil {
30+
p.WaitTimeout = &metav1.Duration{Duration: 1 * time.Minute}
3131
}
3232
}
3333

@@ -54,12 +54,12 @@ func (wt WindowType) IsValid() bool {
5454
}
5555
}
5656

57-
func (up *UsageParams) GetExtraDurationParamOrDefault(key string, defaultValue time.Duration) time.Duration {
58-
if up.ExtraParams == nil {
57+
func (p *UsageParams) GetExtraDurationParamOrDefault(key string, defaultValue time.Duration) time.Duration {
58+
if p.ExtraParams == nil {
5959
return defaultValue
6060
}
6161

62-
value, exists := up.ExtraParams[key]
62+
value, exists := p.ExtraParams[key]
6363
if !exists {
6464
return defaultValue
6565
}
@@ -72,12 +72,12 @@ func (up *UsageParams) GetExtraDurationParamOrDefault(key string, defaultValue t
7272
return duration
7373
}
7474

75-
func (up *UsageParams) GetExtraStringParamOrDefault(key string, defaultValue string) string {
76-
if up.ExtraParams == nil {
75+
func (p *UsageParams) GetExtraStringParamOrDefault(key string, defaultValue string) string {
76+
if p.ExtraParams == nil {
7777
return defaultValue
7878
}
7979

80-
value, exists := up.ExtraParams[key]
80+
value, exists := p.ExtraParams[key]
8181
if !exists {
8282
return defaultValue
8383
}

pkg/scheduler/cache/usagedb/api/interface.go

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
package api
55

66
import (
7-
"time"
8-
97
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/queue_info"
8+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
109
)
1110

1211
type Interface interface {
@@ -15,8 +14,8 @@ type Interface interface {
1514
type UsageDBConfig struct {
1615
ClientType string `yaml:"clientType" json:"clientType"`
1716
ConnectionString string `yaml:"connectionString" json:"connectionString"`
18-
ConnectionStringEnvVar string `yaml:"connectionStringEnvVar" json:"connectionStringEnvVar"`
19-
UsageParams *UsageParams `yaml:"usageParams" json:"usageParams"`
17+
ConnectionStringEnvVar string `yaml:"connectionStringEnvVar,omitempty" json:"connectionStringEnvVar,omitempty"`
18+
UsageParams *UsageParams `yaml:"usageParams,omitempty" json:"usageParams,omitempty"`
2019
}
2120

2221
// GetUsageParams returns the usage params if set, and default params if not set.
@@ -29,23 +28,70 @@ func (c *UsageDBConfig) GetUsageParams() *UsageParams {
2928
return &up
3029
}
3130

31+
func (c *UsageDBConfig) DeepCopy() *UsageDBConfig {
32+
out := new(UsageDBConfig)
33+
out.ClientType = c.ClientType
34+
out.ConnectionString = c.ConnectionString
35+
out.ConnectionStringEnvVar = c.ConnectionStringEnvVar
36+
if c.UsageParams != nil {
37+
out.UsageParams = c.UsageParams.DeepCopy()
38+
}
39+
return out
40+
}
41+
3242
// UsageParams defines common params for all usage db clients. Some clients may not support all the params.
3343
type UsageParams struct {
3444
// Half life period of the usage. If not set, or set to 0, the usage will not be decayed.
35-
HalfLifePeriod *time.Duration `yaml:"halfLifePeriod" json:"halfLifePeriod"`
45+
HalfLifePeriod *metav1.Duration `yaml:"halfLifePeriod,omitempty" json:"halfLifePeriod,omitempty"`
3646
// Window size of the usage. Default is 1 week.
37-
WindowSize *time.Duration `yaml:"windowSize" json:"windowSize"`
47+
WindowSize *metav1.Duration `yaml:"windowSize,omitempty" json:"windowSize,omitempty"`
3848
// Window type for time-series aggregation. If not set, defaults to sliding.
39-
WindowType *WindowType `yaml:"windowType" json:"windowType"`
49+
WindowType *WindowType `yaml:"windowType,omitempty" json:"windowType,omitempty"`
4050
// A cron string used to determine when to reset resource usage for all queues.
41-
TumblingWindowCronString string `yaml:"tumblingWindowCronString" json:"tumblingWindowCronString"`
51+
TumblingWindowCronString string `yaml:"tumblingWindowCronString,omitempty" json:"tumblingWindowCronString,omitempty"`
4252
// Fetch interval of the usage. Default is 1 minute.
43-
FetchInterval *time.Duration `yaml:"fetchInterval" json:"fetchInterval"`
53+
FetchInterval *metav1.Duration `yaml:"fetchInterval,omitempty" json:"fetchInterval,omitempty"`
4454
// Staleness period of the usage. Default is 5 minutes.
45-
StalenessPeriod *time.Duration `yaml:"stalenessPeriod" json:"stalenessPeriod"`
55+
StalenessPeriod *metav1.Duration `yaml:"stalenessPeriod,omitempty" json:"stalenessPeriod,omitempty"`
4656
// Wait timeout of the usage. Default is 1 minute.
47-
WaitTimeout *time.Duration `yaml:"waitTimeout" json:"waitTimeout"`
57+
WaitTimeout *metav1.Duration `yaml:"waitTimeout,omitempty" json:"waitTimeout,omitempty"`
4858

4959
// ExtraParams are extra parameters for the usage db client, which are client specific.
50-
ExtraParams map[string]string `yaml:"extraParams" json:"extraParams"`
60+
ExtraParams map[string]string `yaml:"extraParams,omitempty" json:"extraParams,omitempty"`
61+
}
62+
63+
func (p *UsageParams) DeepCopy() *UsageParams {
64+
out := new(UsageParams)
65+
if p.HalfLifePeriod != nil {
66+
duration := *p.HalfLifePeriod
67+
out.HalfLifePeriod = &duration
68+
}
69+
if p.WindowSize != nil {
70+
duration := *p.WindowSize
71+
out.WindowSize = &duration
72+
}
73+
if p.WindowType != nil {
74+
windowType := *p.WindowType
75+
out.WindowType = &windowType
76+
}
77+
out.TumblingWindowCronString = p.TumblingWindowCronString
78+
if p.FetchInterval != nil {
79+
duration := *p.FetchInterval
80+
out.FetchInterval = &duration
81+
}
82+
if p.StalenessPeriod != nil {
83+
duration := *p.StalenessPeriod
84+
out.StalenessPeriod = &duration
85+
}
86+
if p.WaitTimeout != nil {
87+
duration := *p.WaitTimeout
88+
out.WaitTimeout = &duration
89+
}
90+
if p.ExtraParams != nil {
91+
out.ExtraParams = make(map[string]string, len(p.ExtraParams))
92+
for k, v := range p.ExtraParams {
93+
out.ExtraParams[k] = v
94+
}
95+
}
96+
return out
5197
}

0 commit comments

Comments
 (0)