Skip to content

Commit 413da20

Browse files
authored
Merge pull request #362 from elezar/add-feature-flags
Add support for feature flags
2 parents c374520 + 09341a0 commit 413da20

File tree

4 files changed

+105
-15
lines changed

4 files changed

+105
-15
lines changed

cmd/nvidia-ctk/config/config.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ func run(c *cli.Context, opts *options) error {
109109
if err != nil {
110110
return fmt.Errorf("invalid --set option %v: %w", set, err)
111111
}
112-
cfgToml.Set(key, value)
112+
if value == nil {
113+
_ = cfgToml.Delete(key)
114+
} else {
115+
cfgToml.Set(key, value)
116+
}
113117
}
114118

115119
if err := opts.EnsureOutputFolder(); err != nil {
@@ -146,20 +150,25 @@ func setFlagToKeyValue(setFlag string) (string, interface{}, error) {
146150

147151
kind := field.Kind()
148152
if len(setParts) != 2 {
149-
if kind == reflect.Bool {
153+
if kind == reflect.Bool || (kind == reflect.Pointer && field.Elem().Kind() == reflect.Bool) {
150154
return key, true, nil
151155
}
152156
return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag)
153157
}
154158

155159
value := setParts[1]
160+
if kind == reflect.Pointer && value != "nil" {
161+
kind = field.Elem().Kind()
162+
}
156163
switch kind {
164+
case reflect.Pointer:
165+
return key, nil, nil
157166
case reflect.Bool:
158167
b, err := strconv.ParseBool(value)
159168
if err != nil {
160169
return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err)
161170
}
162-
return key, b, err
171+
return key, b, nil
163172
case reflect.String:
164173
return key, value, nil
165174
case reflect.Slice:
@@ -201,7 +210,7 @@ func getStruct(current reflect.Type, paths ...string) (reflect.StructField, erro
201210
if !ok {
202211
continue
203212
}
204-
if v != tomlField {
213+
if strings.SplitN(v, ",", 2)[0] != tomlField {
205214
continue
206215
}
207216
if len(paths) == 1 {

internal/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ type Config struct {
6363
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
6464
NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"`
6565
NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"`
66+
67+
// Features allows for finer control over optional features.
68+
Features features `toml:"features,omitempty"`
6669
}
6770

6871
// GetConfigFilePath returns the path to the config file for the configured system

internal/config/features.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/**
2+
# Copyright 2024 NVIDIA CORPORATION
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package config
18+
19+
type featureName string
20+
21+
const (
22+
FeatureGDS = featureName("gds")
23+
FeatureMOFED = featureName("mofed")
24+
FeatureNVSWITCH = featureName("nvswitch")
25+
FeatureGDRCopy = featureName("gdrcopy")
26+
)
27+
28+
// features specifies a set of named features.
29+
type features struct {
30+
GDS *feature `toml:"gds,omitempty"`
31+
MOFED *feature `toml:"mofed,omitempty"`
32+
NVSWITCH *feature `toml:"nvswitch,omitempty"`
33+
GDRCopy *feature `toml:"gdrcopy,omitempty"`
34+
}
35+
36+
type feature bool
37+
38+
// IsEnabled checks whether a specified named feature is enabled.
39+
// An optional list of environments to check for feature-specific environment
40+
// variables can also be supplied.
41+
func (fs features) IsEnabled(n featureName, in ...getenver) bool {
42+
featureEnvvars := map[featureName]string{
43+
FeatureGDS: "NVIDIA_GDS",
44+
FeatureMOFED: "NVIDIA_MOFED",
45+
FeatureNVSWITCH: "NVIDIA_NVSWITCH",
46+
FeatureGDRCopy: "NVIDIA_GDRCOPY",
47+
}
48+
49+
envvar := featureEnvvars[n]
50+
switch n {
51+
case FeatureGDS:
52+
return fs.GDS.isEnabled(envvar, in...)
53+
case FeatureMOFED:
54+
return fs.MOFED.isEnabled(envvar, in...)
55+
case FeatureNVSWITCH:
56+
return fs.NVSWITCH.isEnabled(envvar, in...)
57+
case FeatureGDRCopy:
58+
return fs.GDRCopy.isEnabled(envvar, in...)
59+
default:
60+
return false
61+
}
62+
}
63+
64+
// isEnabled checks whether a feature is enabled.
65+
// If the enabled value is explicitly set, this is returned, otherwise the
66+
// associated envvar is checked in the specified getenver for the string "enabled"
67+
// A CUDA container / image can be passed here.
68+
func (f *feature) isEnabled(envvar string, ins ...getenver) bool {
69+
if f != nil {
70+
return bool(*f)
71+
}
72+
if envvar == "" {
73+
return false
74+
}
75+
for _, in := range ins {
76+
if in.Getenv(envvar) == "enabled" {
77+
return true
78+
}
79+
}
80+
return false
81+
}
82+
83+
type getenver interface {
84+
Getenv(string) string
85+
}

internal/modifier/gated.go

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@ import (
2626
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2727
)
2828

29-
const (
30-
nvidiaGDSEnvvar = "NVIDIA_GDS"
31-
nvidiaMOFEDEnvvar = "NVIDIA_MOFED"
32-
nvidiaNVSWITCHEnvvar = "NVIDIA_NVSWITCH"
33-
nvidiaGDRCOPYEnvvar = "NVIDIA_GDRCOPY"
34-
)
35-
3629
// NewFeatureGatedModifier creates the modifiers for optional features.
3730
// These include:
3831
//
@@ -53,31 +46,31 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
5346
driverRoot := cfg.NVIDIAContainerCLIConfig.Root
5447
devRoot := cfg.NVIDIAContainerCLIConfig.Root
5548

56-
if image.Getenv(nvidiaGDSEnvvar) == "enabled" {
49+
if cfg.Features.IsEnabled(config.FeatureGDS, image) {
5750
d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot)
5851
if err != nil {
5952
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
6053
}
6154
discoverers = append(discoverers, d)
6255
}
6356

64-
if image.Getenv(nvidiaMOFEDEnvvar) == "enabled" {
57+
if cfg.Features.IsEnabled(config.FeatureMOFED, image) {
6558
d, err := discover.NewMOFEDDiscoverer(logger, devRoot)
6659
if err != nil {
6760
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
6861
}
6962
discoverers = append(discoverers, d)
7063
}
7164

72-
if image.Getenv(nvidiaNVSWITCHEnvvar) == "enabled" {
65+
if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) {
7366
d, err := discover.NewNvSwitchDiscoverer(logger, devRoot)
7467
if err != nil {
7568
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
7669
}
7770
discoverers = append(discoverers, d)
7871
}
7972

80-
if image.Getenv(nvidiaGDRCOPYEnvvar) == "enabled" {
73+
if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) {
8174
d, err := discover.NewGDRCopyDiscoverer(logger, devRoot)
8275
if err != nil {
8376
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)

0 commit comments

Comments
 (0)