Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions cmd/nvidia-ctk/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ func run(c *cli.Context, opts *options) error {
if err != nil {
return fmt.Errorf("invalid --set option %v: %w", set, err)
}
cfgToml.Set(key, value)
if value == nil {
_ = cfgToml.Delete(key)
} else {
cfgToml.Set(key, value)
}
}

if err := opts.EnsureOutputFolder(); err != nil {
Expand Down Expand Up @@ -146,20 +150,25 @@ func setFlagToKeyValue(setFlag string) (string, interface{}, error) {

kind := field.Kind()
if len(setParts) != 2 {
if kind == reflect.Bool {
if kind == reflect.Bool || (kind == reflect.Pointer && field.Elem().Kind() == reflect.Bool) {
return key, true, nil
}
return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag)
}

value := setParts[1]
if kind == reflect.Pointer && value != "nil" {
kind = field.Elem().Kind()
}
switch kind {
case reflect.Pointer:
return key, nil, nil
case reflect.Bool:
b, err := strconv.ParseBool(value)
if err != nil {
return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err)
}
return key, b, err
return key, b, nil
case reflect.String:
return key, value, nil
case reflect.Slice:
Expand Down Expand Up @@ -201,7 +210,7 @@ func getStruct(current reflect.Type, paths ...string) (reflect.StructField, erro
if !ok {
continue
}
if v != tomlField {
if strings.SplitN(v, ",", 2)[0] != tomlField {
continue
}
if len(paths) == 1 {
Expand Down
3 changes: 3 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type Config struct {
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"`
NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"`

// Features allows for finer control over optional features.
Features features `toml:"features,omitempty"`
}

// GetConfigFilePath returns the path to the config file for the configured system
Expand Down
85 changes: 85 additions & 0 deletions internal/config/features.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package config

type featureName string

const (
FeatureGDS = featureName("gds")
FeatureMOFED = featureName("mofed")
FeatureNVSWITCH = featureName("nvswitch")
FeatureGDRCopy = featureName("gdrcopy")
)

// features specifies a set of named features.
type features struct {
GDS *feature `toml:"gds,omitempty"`
MOFED *feature `toml:"mofed,omitempty"`
NVSWITCH *feature `toml:"nvswitch,omitempty"`
GDRCopy *feature `toml:"gdrcopy,omitempty"`
}

type feature bool

// IsEnabled checks whether a specified named feature is enabled.
// An optional list of environments to check for feature-specific environment
// variables can also be supplied.
func (fs features) IsEnabled(n featureName, in ...getenver) bool {
featureEnvvars := map[featureName]string{
FeatureGDS: "NVIDIA_GDS",
FeatureMOFED: "NVIDIA_MOFED",
FeatureNVSWITCH: "NVIDIA_NVSWITCH",
FeatureGDRCopy: "NVIDIA_GDRCOPY",
}

envvar := featureEnvvars[n]
switch n {
case FeatureGDS:
return fs.GDS.isEnabled(envvar, in...)
case FeatureMOFED:
return fs.MOFED.isEnabled(envvar, in...)
case FeatureNVSWITCH:
return fs.NVSWITCH.isEnabled(envvar, in...)
case FeatureGDRCopy:
return fs.GDRCopy.isEnabled(envvar, in...)
default:
return false
}
}

// isEnabled checks whether a feature is enabled.
// If the enabled value is explicitly set, this is returned, otherwise the
// associated envvar is checked in the specified getenver for the string "enabled"
// A CUDA container / image can be passed here.
func (f *feature) isEnabled(envvar string, ins ...getenver) bool {
if f != nil {
return bool(*f)
}
if envvar == "" {
return false
}
for _, in := range ins {
if in.Getenv(envvar) == "enabled" {
return true
}
}
return false
}

type getenver interface {
Getenv(string) string
}
15 changes: 4 additions & 11 deletions internal/modifier/gated.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

const (
nvidiaGDSEnvvar = "NVIDIA_GDS"
nvidiaMOFEDEnvvar = "NVIDIA_MOFED"
nvidiaNVSWITCHEnvvar = "NVIDIA_NVSWITCH"
nvidiaGDRCOPYEnvvar = "NVIDIA_GDRCOPY"
)

// NewFeatureGatedModifier creates the modifiers for optional features.
// These include:
//
Expand All @@ -53,31 +46,31 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
driverRoot := cfg.NVIDIAContainerCLIConfig.Root
devRoot := cfg.NVIDIAContainerCLIConfig.Root

if image.Getenv(nvidiaGDSEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureGDS, image) {
d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
}
discoverers = append(discoverers, d)
}

if image.Getenv(nvidiaMOFEDEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureMOFED, image) {
d, err := discover.NewMOFEDDiscoverer(logger, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
}
discoverers = append(discoverers, d)
}

if image.Getenv(nvidiaNVSWITCHEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) {
d, err := discover.NewNvSwitchDiscoverer(logger, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
}
discoverers = append(discoverers, d)
}

if image.Getenv(nvidiaGDRCOPYEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) {
d, err := discover.NewGDRCopyDiscoverer(logger, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)
Expand Down