Skip to content

Commit 7bffae9

Browse files
committed
[no-relnote] Refactor handling of feature flags
Signed-off-by: Evan Lezar <[email protected]>
1 parent d4664fe commit 7bffae9

File tree

2 files changed

+39
-34
lines changed

2 files changed

+39
-34
lines changed

internal/config/features.go

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ package config
1919
type featureName string
2020

2121
const (
22-
FeatureGDS = featureName("gds")
23-
FeatureMOFED = featureName("mofed")
24-
FeatureNVSWITCH = featureName("nvswitch")
25-
FeatureGDRCopy = featureName("gdrcopy")
22+
FeatureGDS = featureName("gds")
23+
FeatureMOFED = featureName("mofed")
24+
FeatureNVSWITCH = featureName("nvswitch")
25+
FeatureGDRCopy = featureName("gdrcopy")
26+
FeatureIncludePersistencedSocket = featureName("include-persistenced-socket")
2627
)
2728

2829
// features specifies a set of named features.
@@ -31,53 +32,57 @@ type features struct {
3132
MOFED *feature `toml:"mofed,omitempty"`
3233
NVSWITCH *feature `toml:"nvswitch,omitempty"`
3334
GDRCopy *feature `toml:"gdrcopy,omitempty"`
35+
// IncludePersistencedSocket enables the injection of the nvidia-persistenced
36+
// socket into containers.
37+
IncludePersistencedSocket *feature `toml:"include-persistenced-socket,omitempty"`
3438
}
3539

3640
type feature bool
3741

38-
// IsEnabled checks whether a specified named feature is enabled.
42+
// IsEnabledInEnvironment checks whether a specified named feature is enabled.
3943
// An optional list of environments to check for feature-specific environment
4044
// variables can also be supplied.
41-
func (fs features) IsEnabled(n featureName, in ...getenver) bool {
42-
featureEnvvars := map[featureName]string{
43-
FeatureGDS: "NVIDIA_GDS",
44-
FeatureMOFED: "NVIDIA_MOFED",
45-
FeatureNVSWITCH: "NVIDIA_NVSWITCH",
46-
FeatureGDRCopy: "NVIDIA_GDRCOPY",
47-
}
48-
49-
envvar := featureEnvvars[n]
45+
func (fs features) IsEnabledInEnvironment(n featureName, in ...getenver) bool {
5046
switch n {
47+
// Features with envvar overrides
5148
case FeatureGDS:
52-
return fs.GDS.isEnabled(envvar, in...)
49+
return fs.GDS.isEnabledWithEnvvarOverride("NVIDIA_GDS", in...)
5350
case FeatureMOFED:
54-
return fs.MOFED.isEnabled(envvar, in...)
51+
return fs.MOFED.isEnabledWithEnvvarOverride("NVIDIA_MOFED", in...)
5552
case FeatureNVSWITCH:
56-
return fs.NVSWITCH.isEnabled(envvar, in...)
53+
return fs.NVSWITCH.isEnabledWithEnvvarOverride("NVIDIA_NVSWITCH", in...)
5754
case FeatureGDRCopy:
58-
return fs.GDRCopy.isEnabled(envvar, in...)
55+
return fs.GDRCopy.isEnabledWithEnvvarOverride("NVIDIA_GDRCOPY", in...)
56+
// Features without envvar overrides
57+
case FeatureIncludePersistencedSocket:
58+
return fs.IncludePersistencedSocket.IsEnabled()
5959
default:
6060
return false
6161
}
6262
}
6363

64-
// isEnabled checks whether a feature is enabled.
65-
// If the enabled value is explicitly set, this is returned, otherwise the
66-
// associated envvar is checked in the specified getenver for the string "enabled"
67-
// A CUDA container / image can be passed here.
68-
func (f *feature) isEnabled(envvar string, ins ...getenver) bool {
64+
// IsEnabled checks whether a feature is enabled.
65+
func (f *feature) IsEnabled() bool {
6966
if f != nil {
7067
return bool(*f)
7168
}
72-
if envvar == "" {
73-
return false
74-
}
75-
for _, in := range ins {
76-
if in.Getenv(envvar) == "enabled" {
77-
return true
69+
return false
70+
}
71+
72+
// isEnabledWithEnvvarOverride checks whether a feature is enabled and allows an envvar to overide the feature.
73+
// If the enabled value is explicitly set, this is returned, otherwise the
74+
// associated envvar is checked in the specified getenver for the string "enabled"
75+
// A CUDA container / image can be passed here.
76+
func (f *feature) isEnabledWithEnvvarOverride(envvar string, ins ...getenver) bool {
77+
if envvar != "" {
78+
for _, in := range ins {
79+
if in.Getenv(envvar) == "enabled" {
80+
return true
81+
}
7882
}
7983
}
80-
return false
84+
85+
return f.IsEnabled()
8186
}
8287

8388
type getenver interface {

internal/modifier/gated.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,31 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
4646
driverRoot := cfg.NVIDIAContainerCLIConfig.Root
4747
devRoot := cfg.NVIDIAContainerCLIConfig.Root
4848

49-
if cfg.Features.IsEnabled(config.FeatureGDS, image) {
49+
if cfg.Features.IsEnabledInEnvironment(config.FeatureGDS, image) {
5050
d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot)
5151
if err != nil {
5252
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
5353
}
5454
discoverers = append(discoverers, d)
5555
}
5656

57-
if cfg.Features.IsEnabled(config.FeatureMOFED, image) {
57+
if cfg.Features.IsEnabledInEnvironment(config.FeatureMOFED, image) {
5858
d, err := discover.NewMOFEDDiscoverer(logger, devRoot)
5959
if err != nil {
6060
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
6161
}
6262
discoverers = append(discoverers, d)
6363
}
6464

65-
if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) {
65+
if cfg.Features.IsEnabledInEnvironment(config.FeatureNVSWITCH, image) {
6666
d, err := discover.NewNvSwitchDiscoverer(logger, devRoot)
6767
if err != nil {
6868
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
6969
}
7070
discoverers = append(discoverers, d)
7171
}
7272

73-
if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) {
73+
if cfg.Features.IsEnabledInEnvironment(config.FeatureGDRCopy, image) {
7474
d, err := discover.NewGDRCopyDiscoverer(logger, devRoot)
7575
if err != nil {
7676
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)

0 commit comments

Comments
 (0)