Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions internal/modifier/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,32 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

type list struct {
modifiers []oci.SpecModifier
}
type List []oci.SpecModifier

// Merge merges a set of OCI specification modifiers as a list.
// This can be used to compose modifiers.
func Merge(modifiers ...oci.SpecModifier) oci.SpecModifier {
var filteredModifiers []oci.SpecModifier
var filteredModifiers List
for _, m := range modifiers {
if m == nil {
continue
}
filteredModifiers = append(filteredModifiers, m)
}

return list{
modifiers: filteredModifiers,
}
return filteredModifiers
}

// Modify applies a list of modifiers in sequence and returns on any errors encountered.
func (m list) Modify(spec *specs.Spec) error {
for _, mm := range m.modifiers {
func (m List) Modify(spec *specs.Spec) error {
for _, mm := range m {
if mm == nil {
continue
}
err := mm.Modify(spec)
if err != nil {
return err
}
}

return nil
}
49 changes: 32 additions & 17 deletions internal/runtime/runtime_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,27 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
if err != nil {
return nil, err
}
// For CDI mode we make no additional modifications.
if mode == "cdi" {
return modeModifier, nil
}

graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
if err != nil {
return nil, err
}

featureModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image)
if err != nil {
return nil, err
var modifiers modifier.List
for _, modifierType := range supportedModifierTypes(mode) {
switch modifierType {
case "mode":
modifiers = append(modifiers, modeModifier)
case "graphics":
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
if err != nil {
return nil, err
}
modifiers = append(modifiers, graphicsModifier)
case "feature-gated":
featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image)
if err != nil {
return nil, err
}
modifiers = append(modifiers, featureGatedModifier)
}
}

modifiers := modifier.Merge(
modeModifier,
graphicsModifier,
featureModifier,
)
return modifiers, nil
}

Expand All @@ -114,3 +115,17 @@ func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, o

return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode)
}

// supportedModifierTypes returns the modifiers supported for a specific runtime mode.
func supportedModifierTypes(mode string) []string {
switch mode {
case "cdi":
// For CDI mode we make no additional modifications.
return []string{"mode"}
case "csv":
// For CSV mode we support mode and feature-gated modification.
return []string{"mode", "feature-gated"}
default:
return []string{"mode", "graphics", "feature-gated"}
}
}
Loading