Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/env-tests/timeaware/timeaware.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ func setupControllers(backgroundCtx context.Context, cfg *rest.Config,
ClientType: "fake-with-history",
ConnectionString: "fake-connection",
UsageParams: &api.UsageParams{
WindowSize: &[]time.Duration{time.Second * time.Duration(*windowSize)}[0],
FetchInterval: &[]time.Duration{time.Millisecond}[0],
HalfLifePeriod: &[]time.Duration{time.Second * time.Duration(*halfLifePeriod)}[0],
WindowSize: &metav1.Duration{Duration: time.Second * time.Duration(*windowSize)},
FetchInterval: &metav1.Duration{Duration: time.Millisecond},
HalfLifePeriod: &metav1.Duration{Duration: time.Second * time.Duration(*halfLifePeriod)},
},
}
schedulerConf.UsageDBConfig.UsageParams.SetDefaults()
Expand Down
21 changes: 11 additions & 10 deletions pkg/operator/operands/scheduler/resources_for_shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strconv"
"strings"

"github.com/spf13/pflag"
"golang.org/x/exp/slices"

"gopkg.in/yaml.v3"
Expand All @@ -22,7 +23,7 @@ import (
kaiv1 "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1"
kaiConfigUtils "github.com/NVIDIA/KAI-scheduler/pkg/operator/config"
"github.com/NVIDIA/KAI-scheduler/pkg/operator/operands/common"
"github.com/spf13/pflag"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/conf"
)

const (
Expand Down Expand Up @@ -125,7 +126,7 @@ func (s *SchedulerForShard) configMapForShard(
APIVersion: "v1",
}
placementArguments := calculatePlacementArguments(shard.Spec.PlacementStrategy)
innerConfig := config{}
innerConfig := conf.SchedulerConfiguration{}

actions := []string{"allocate"}
if placementArguments[gpuResource] != spreadStrategy && placementArguments[cpuResource] != spreadStrategy {
Expand All @@ -135,9 +136,9 @@ func (s *SchedulerForShard) configMapForShard(

innerConfig.Actions = strings.Join(actions, ", ")

innerConfig.Tiers = []tier{
innerConfig.Tiers = []conf.Tier{
{
Plugins: []plugin{
Plugins: []conf.PluginOption{
{Name: "predicates"},
{Name: "proportion"},
{Name: "priority"},
Expand All @@ -160,8 +161,8 @@ func (s *SchedulerForShard) configMapForShard(

innerConfig.Tiers[0].Plugins = append(
innerConfig.Tiers[0].Plugins,
plugin{Name: fmt.Sprintf("gpu%s", strings.Replace(placementArguments[gpuResource], "bin", "", 1))},
plugin{
conf.PluginOption{Name: fmt.Sprintf("gpu%s", strings.Replace(placementArguments[gpuResource], "bin", "", 1))},
conf.PluginOption{
Name: "nodeplacement",
Arguments: placementArguments,
},
Expand All @@ -170,7 +171,7 @@ func (s *SchedulerForShard) configMapForShard(
if placementArguments[gpuResource] == binpackStrategy {
innerConfig.Tiers[0].Plugins = append(
innerConfig.Tiers[0].Plugins,
plugin{Name: "gpusharingorder"},
conf.PluginOption{Name: "gpusharingorder"},
)
}

Expand All @@ -195,7 +196,7 @@ func (s *SchedulerForShard) configMapForShard(
return schedulerConfig, nil
}

func validateJobDepthMap(shard *kaiv1.SchedulingShard, innerConfig config, actions []string) error {
func validateJobDepthMap(shard *kaiv1.SchedulingShard, innerConfig conf.SchedulerConfiguration, actions []string) error {
for actionToConfigure := range shard.Spec.QueueDepthPerAction {
if !slices.Contains(actions, actionToConfigure) {
return fmt.Errorf(invalidJobDepthMapError, innerConfig.Actions, actionToConfigure)
Expand Down Expand Up @@ -294,12 +295,12 @@ func calculatePlacementArguments(placementStrategy *kaiv1.PlacementStrategy) map
}
}

func addMinRuntimePluginIfNeeded(plugins *[]plugin, minRuntime *kaiv1.MinRuntime) {
func addMinRuntimePluginIfNeeded(plugins *[]conf.PluginOption, minRuntime *kaiv1.MinRuntime) {
if minRuntime == nil || (minRuntime.PreemptMinRuntime == nil && minRuntime.ReclaimMinRuntime == nil) {
return
}

minRuntimePlugin := plugin{Name: "minruntime", Arguments: map[string]string{}}
minRuntimePlugin := conf.PluginOption{Name: "minruntime", Arguments: map[string]string{}}

if minRuntime.PreemptMinRuntime != nil {
minRuntimePlugin.Arguments["defaultPreemptMinRuntime"] = *minRuntime.PreemptMinRuntime
Expand Down
7 changes: 4 additions & 3 deletions pkg/operator/operands/scheduler/resources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
kaiv1 "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1"
kaiv1qc "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1/queue_controller"
kaiv1scheduler "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1/scheduler"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/conf"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -187,7 +188,7 @@ func TestValidateJobDepthMap(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
innerConfig := config{
innerConfig := conf.SchedulerConfiguration{
Actions: strings.Join(tt.actions, ", "),
}

Expand Down Expand Up @@ -454,15 +455,15 @@ tiers:
require.True(t, found, "ConfigMap missing config.yaml")

// Unmarshal expected YAML from test case
var expectedConfig config
var expectedConfig conf.SchedulerConfiguration
if _, ok := tt.expected["config.yaml"]; !ok {
t.Fatal("Test case must provide expected YAML for config.yaml")
}
err = yaml.Unmarshal([]byte(tt.expected["config.yaml"]), &expectedConfig)
require.NoError(t, err, "Failed to unmarshal expected config")

// Unmarshal actual YAML from ConfigMap
var actualConfig config
var actualConfig conf.SchedulerConfiguration
err = yaml.Unmarshal([]byte(actualYAML), &actualConfig)
require.NoError(t, err, "Failed to unmarshal actual config")

Expand Down
15 changes: 0 additions & 15 deletions pkg/operator/operands/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ const (
defaultResourceName = "scheduler"
)

// config is the in-package representation of the scheduler configuration
// document; its fields map onto the YAML keys named in the struct tags.
type config struct {
// Actions holds the scheduler actions as a single string under the
// "actions" key.
Actions string `yaml:"actions"`
// Tiers lists the plugin tiers; the key is omitted from the YAML when empty.
Tiers []tier `yaml:"tiers,omitempty"`
// QueueDepthPerAction maps a string key to an integer depth; omitted from
// the YAML when empty.
// NOTE(review): keys are presumably action names — confirm against callers.
QueueDepthPerAction map[string]int `yaml:"queueDepthPerAction,omitempty"`
}

// tier is one entry of config.Tiers: a group of plugins serialized under the
// "plugins" YAML key.
type tier struct {
// Plugins is the list of plugins belonging to this tier.
Plugins []plugin `yaml:"plugins"`
}

// plugin describes a single scheduler plugin entry inside a tier.
type plugin struct {
// Name identifies the plugin and is serialized under the "name" key.
Name string `yaml:"name"`
// Arguments carries optional string key/value settings for the plugin;
// the key is omitted from the YAML when the map is empty.
Arguments map[string]string `yaml:"arguments,omitempty"`
}

type SchedulerForShard struct {
schedulingShard *kaiv1.SchedulingShard

Expand Down
4 changes: 2 additions & 2 deletions pkg/operator/operands/scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ var _ = Describe("Scheduler", func() {
cm := cmObj.(*v1.ConfigMap)

Expect(err).To(BeNil())
Expect(cm.Data["config.yaml"]).To(Equal(`actions: allocate, consolidation, reclaim, preempt, stalegangeviction
Expect(cm.Data["config.yaml"]).To(MatchYAML(`actions: allocate, consolidation, reclaim, preempt, stalegangeviction
tiers:
- plugins:
- name: predicates
Expand Down Expand Up @@ -176,7 +176,7 @@ tiers:
cm := cmObj.(*v1.ConfigMap)

Expect(err).To(BeNil())
Expect(cm.Data["config.yaml"]).To(Equal(`actions: allocate, reclaim, preempt, stalegangeviction
Expect(cm.Data["config.yaml"]).To(MatchYAML(`actions: allocate, reclaim, preempt, stalegangeviction
tiers:
- plugins:
- name: predicates
Expand Down
6 changes: 3 additions & 3 deletions pkg/scheduler/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ func newSchedulerCache(schedulerCacheParams *SchedulerCacheParams) *SchedulerCac

if schedulerCacheParams.UsageDBClient != nil {
sc.usageLister = usagedb.NewUsageLister(schedulerCacheParams.UsageDBClient,
schedulerCacheParams.UsageDBParams.FetchInterval,
schedulerCacheParams.UsageDBParams.StalenessPeriod,
schedulerCacheParams.UsageDBParams.WaitTimeout)
&schedulerCacheParams.UsageDBParams.FetchInterval.Duration,
&schedulerCacheParams.UsageDBParams.StalenessPeriod.Duration,
&schedulerCacheParams.UsageDBParams.WaitTimeout.Duration)
}

clusterInfo, err := cluster_info.New(sc.informerFactory, sc.kubeAiSchedulerInformerFactory, sc.kueueInformerFactory, sc.usageLister, sc.schedulingNodePoolParams,
Expand Down
46 changes: 23 additions & 23 deletions pkg/scheduler/cache/usagedb/api/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,31 @@

package api

import "time"
import (
"time"

func (up *UsageParams) SetDefaults() {
if up.HalfLifePeriod == nil {
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func (p *UsageParams) SetDefaults() {
if p.HalfLifePeriod == nil {
// noop: disabled by default
}
if up.WindowSize == nil {
windowSize := time.Hour * 24 * 7
up.WindowSize = &windowSize
if p.WindowSize == nil {
p.WindowSize = &metav1.Duration{Duration: time.Hour * 24 * 7}
}
if up.WindowType == nil {
if p.WindowType == nil {
windowType := SlidingWindow
up.WindowType = &windowType
p.WindowType = &windowType
}
if up.FetchInterval == nil {
fetchInterval := 1 * time.Minute
up.FetchInterval = &fetchInterval
if p.FetchInterval == nil {
p.FetchInterval = &metav1.Duration{Duration: 1 * time.Minute}
}
if up.StalenessPeriod == nil {
stalenessPeriod := 5 * time.Minute
up.StalenessPeriod = &stalenessPeriod
if p.StalenessPeriod == nil {
p.StalenessPeriod = &metav1.Duration{Duration: 5 * time.Minute}
}
if up.WaitTimeout == nil {
waitTimeout := 1 * time.Minute
up.WaitTimeout = &waitTimeout
if p.WaitTimeout == nil {
p.WaitTimeout = &metav1.Duration{Duration: 1 * time.Minute}
}
}

Expand All @@ -54,12 +54,12 @@ func (wt WindowType) IsValid() bool {
}
}

func (up *UsageParams) GetExtraDurationParamOrDefault(key string, defaultValue time.Duration) time.Duration {
if up.ExtraParams == nil {
func (p *UsageParams) GetExtraDurationParamOrDefault(key string, defaultValue time.Duration) time.Duration {
if p.ExtraParams == nil {
return defaultValue
}

value, exists := up.ExtraParams[key]
value, exists := p.ExtraParams[key]
if !exists {
return defaultValue
}
Expand All @@ -72,12 +72,12 @@ func (up *UsageParams) GetExtraDurationParamOrDefault(key string, defaultValue t
return duration
}

func (up *UsageParams) GetExtraStringParamOrDefault(key string, defaultValue string) string {
if up.ExtraParams == nil {
func (p *UsageParams) GetExtraStringParamOrDefault(key string, defaultValue string) string {
if p.ExtraParams == nil {
return defaultValue
}

value, exists := up.ExtraParams[key]
value, exists := p.ExtraParams[key]
if !exists {
return defaultValue
}
Expand Down
70 changes: 58 additions & 12 deletions pkg/scheduler/cache/usagedb/api/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
package api

import (
"time"

"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/queue_info"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

type Interface interface {
Expand All @@ -15,8 +14,8 @@ type Interface interface {
type UsageDBConfig struct {
ClientType string `yaml:"clientType" json:"clientType"`
ConnectionString string `yaml:"connectionString" json:"connectionString"`
ConnectionStringEnvVar string `yaml:"connectionStringEnvVar" json:"connectionStringEnvVar"`
UsageParams *UsageParams `yaml:"usageParams" json:"usageParams"`
ConnectionStringEnvVar string `yaml:"connectionStringEnvVar,omitempty" json:"connectionStringEnvVar,omitempty"`
UsageParams *UsageParams `yaml:"usageParams,omitempty" json:"usageParams,omitempty"`
}

// GetUsageParams returns the usage params if set, and default params if not set.
Expand All @@ -29,23 +28,70 @@ func (c *UsageDBConfig) GetUsageParams() *UsageParams {
return &up
}

// DeepCopy returns a copy of the config that shares no pointers with the
// receiver. The string fields are copied by value and UsageParams, when set,
// is duplicated through its own DeepCopy. A nil receiver yields nil.
func (c *UsageDBConfig) DeepCopy() *UsageDBConfig {
	// Nil in, nil out: lets callers copy an optional config without a
	// separate nil check instead of panicking on the field reads below.
	if c == nil {
		return nil
	}

	out := new(UsageDBConfig)
	out.ClientType = c.ClientType
	out.ConnectionString = c.ConnectionString
	out.ConnectionStringEnvVar = c.ConnectionStringEnvVar
	// Leave UsageParams nil when the source has none so "unset" is preserved.
	if c.UsageParams != nil {
		out.UsageParams = c.UsageParams.DeepCopy()
	}
	return out
}

// UsageParams defines common params for all usage db clients. Some clients may not support all the params.
type UsageParams struct {
// Half life period of the usage. If not set, or set to 0, the usage will not be decayed.
HalfLifePeriod *time.Duration `yaml:"halfLifePeriod" json:"halfLifePeriod"`
HalfLifePeriod *metav1.Duration `yaml:"halfLifePeriod,omitempty" json:"halfLifePeriod,omitempty"`
// Window size of the usage. Default is 1 week.
WindowSize *time.Duration `yaml:"windowSize" json:"windowSize"`
WindowSize *metav1.Duration `yaml:"windowSize,omitempty" json:"windowSize,omitempty"`
// Window type for time-series aggregation. If not set, defaults to sliding.
WindowType *WindowType `yaml:"windowType" json:"windowType"`
WindowType *WindowType `yaml:"windowType,omitempty" json:"windowType,omitempty"`
// A cron string used to determine when to reset resource usage for all queues.
TumblingWindowCronString string `yaml:"tumblingWindowCronString" json:"tumblingWindowCronString"`
TumblingWindowCronString string `yaml:"tumblingWindowCronString,omitempty" json:"tumblingWindowCronString,omitempty"`
// Fetch interval of the usage. Default is 1 minute.
FetchInterval *time.Duration `yaml:"fetchInterval" json:"fetchInterval"`
FetchInterval *metav1.Duration `yaml:"fetchInterval,omitempty" json:"fetchInterval,omitempty"`
// Staleness period of the usage. Default is 5 minutes.
StalenessPeriod *time.Duration `yaml:"stalenessPeriod" json:"stalenessPeriod"`
StalenessPeriod *metav1.Duration `yaml:"stalenessPeriod,omitempty" json:"stalenessPeriod,omitempty"`
// Wait timeout of the usage. Default is 1 minute.
WaitTimeout *time.Duration `yaml:"waitTimeout" json:"waitTimeout"`
WaitTimeout *metav1.Duration `yaml:"waitTimeout,omitempty" json:"waitTimeout,omitempty"`

// ExtraParams are extra parameters for the usage db client, which are client specific.
ExtraParams map[string]string `yaml:"extraParams" json:"extraParams"`
ExtraParams map[string]string `yaml:"extraParams,omitempty" json:"extraParams,omitempty"`
}

// DeepCopy returns a copy of the params that shares no pointers or maps with
// the receiver. Every optional pointer field that is set is copied into a
// fresh allocation; unset (nil) fields stay nil so defaults can still be
// applied later. A nil receiver yields nil.
func (p *UsageParams) DeepCopy() *UsageParams {
	// Nil in, nil out: avoids a panic on the field reads below when an
	// optional *UsageParams is copied without a prior nil check.
	if p == nil {
		return nil
	}

	out := new(UsageParams)
	if p.HalfLifePeriod != nil {
		duration := *p.HalfLifePeriod
		out.HalfLifePeriod = &duration
	}
	if p.WindowSize != nil {
		duration := *p.WindowSize
		out.WindowSize = &duration
	}
	if p.WindowType != nil {
		windowType := *p.WindowType
		out.WindowType = &windowType
	}
	out.TumblingWindowCronString = p.TumblingWindowCronString
	if p.FetchInterval != nil {
		duration := *p.FetchInterval
		out.FetchInterval = &duration
	}
	if p.StalenessPeriod != nil {
		duration := *p.StalenessPeriod
		out.StalenessPeriod = &duration
	}
	if p.WaitTimeout != nil {
		duration := *p.WaitTimeout
		out.WaitTimeout = &duration
	}
	// A non-nil source map is always reproduced as a non-nil map, even when
	// it is empty, so nil-ness of ExtraParams survives the copy.
	if p.ExtraParams != nil {
		out.ExtraParams = make(map[string]string, len(p.ExtraParams))
		for k, v := range p.ExtraParams {
			out.ExtraParams[k] = v
		}
	}
	return out
}
Loading
Loading