Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions cmd/nvidia-ctk/cdi/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ type options struct {
ignorePatterns []string
}

deviceIDs []string

noAllDevice bool

// the following are used for dependency injection during spec generation.
nvmllib nvml.Interface
}
Expand Down Expand Up @@ -232,6 +236,20 @@ func (m command) build() *cli.Command {
Destination: &opts.featureFlags,
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS"),
},
&cli.StringSliceFlag{
Name: "device-id",
Aliases: []string{"device-ids", "device", "devices"},
Usage: "Restrict generation to the specified device identifiers",
Value: []string{"all"},
Destination: &opts.deviceIDs,
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_DEVICE_IDS"),
},
&cli.BoolFlag{
Name: "no-all-device",
Usage: "Don't generate an `all` device for the resultant spec",
Destination: &opts.noAllDevice,
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_NO_ALL_DEVICE"),
},
},
}

Expand Down Expand Up @@ -373,7 +391,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
return nil, fmt.Errorf("failed to create CDI library: %v", err)
}

allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID(opts.deviceIDs...)
if err != nil {
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
}
Expand All @@ -387,13 +405,18 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
spec.WithVendor(opts.vendor),
spec.WithEdits(*commonEdits.ContainerEdits),
spec.WithFormat(opts.format),
spec.WithMergedDeviceOptions(
transform.WithName(allDeviceName),
transform.WithSkipIfExists(true),
),
spec.WithPermissions(0644),
}

if !opts.noAllDevice {
commonSpecOptions = append(commonSpecOptions,
spec.WithMergedDeviceOptions(
transform.WithName(allDeviceName),
transform.WithSkipIfExists(true),
),
)
}

fullSpec, err := spec.New(
append(commonSpecOptions,
spec.WithClass(opts.class),
Expand Down
4 changes: 4 additions & 0 deletions cmd/nvidia-ctk/cdi/generate/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ containerEdits:
for _, tc := range testCases {
// Apply overrides for all test cases:
tc.options.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
if tc.options.deviceIDs == nil {
tc.options.deviceIDs = []string{"all"}
tc.expectedOptions.deviceIDs = []string{"all"}
}

t.Run(tc.description, func(t *testing.T) {
c := command{
Expand Down
69 changes: 31 additions & 38 deletions internal/platform-support/tegra/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,30 @@
package tegra

import (
"fmt"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
)

// newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied.
// The constructed discoverer is comprised of a list, with each element in the list being associated with a
// single CSV files.
func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
if len(o.csvFiles) == 0 {
o.logger.Warningf("No CSV files specified")
func (o options) newDiscovererFromMountSpecs() (discover.Discover, error) {
pathsByType := o.MountSpecPathsByType()
if len(pathsByType) == 0 {
o.logger.Warningf("No mount specs specified")
return discover.None{}, nil
}

targetsByType := getTargetsFromCSVFiles(o.logger, o.csvFiles)

devices := discover.NewCharDeviceDiscoverer(
o.logger,
o.devRoot,
targetsByType[csv.MountSpecDev],
pathsByType[csv.MountSpecDev],
)

directories := discover.NewMounts(
o.logger,
lookup.NewDirectoryLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)),
o.driverRoot,
targetsByType[csv.MountSpecDir],
pathsByType[csv.MountSpecDir],
)

// We create a discoverer for mounted libraries and add additional .so
Expand All @@ -57,14 +51,14 @@ func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
o.logger,
o.symlinkLocator,
o.driverRoot,
targetsByType[csv.MountSpecLib],
pathsByType[csv.MountSpecLib],
),
"",
o.hookCreator,
)

// We process the explicitly requested symlinks.
symlinkTargets := o.ignorePatterns.Apply(targetsByType[csv.MountSpecSym]...)
symlinkTargets := pathsByType[csv.MountSpecSym]
o.logger.Debugf("Filtered symlink targets: %v", symlinkTargets)
symlinks := discover.NewMounts(
o.logger,
Expand All @@ -85,35 +79,34 @@ func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
return d, nil
}

// getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files.
// These are aggregated by mount spec type.
// TODO: We use a function variable here to allow this to be overridden for testing.
// This should be properly mocked.
var getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
targetsByType := make(map[csv.MountSpecType][]string)
for _, filename := range files {
targets, err := loadCSVFile(logger, filename)
if err != nil {
logger.Warningf("Skipping CSV file %v: %v", filename, err)
continue
}
for _, t := range targets {
targetsByType[t.Type] = append(targetsByType[t.Type], t.Path)
}
// MountSpecsFromCSVFiles returns a MountSpecPathsByTyper for the specified list
// of CSV files.
func MountSpecsFromCSVFiles(logger logger.Interface, csvFiles ...string) MountSpecPathsByTyper {
var tts []MountSpecPathsByTyper

for _, filename := range csvFiles {
tts = append(tts, &fromCSVFile{logger, filename})
}
return targetsByType
return Merge(tts...)
}

// loadCSVFile loads the specified CSV file and returns the list of mount specs
func loadCSVFile(logger logger.Interface, filename string) ([]*csv.MountSpec, error) {
type fromCSVFile struct {
logger logger.Interface
filename string
}

// MountSpecPathsByType returns mountspecs defined in the specified CSV file.
func (t *fromCSVFile) MountSpecPathsByType() MountSpecPathsByType {
// Create a discoverer for each file-kind combination
targets, err := csv.NewCSVFileParser(logger, filename).Parse()
targets, err := csv.NewCSVFileParser(t.logger, t.filename).Parse()
if err != nil {
return nil, fmt.Errorf("failed to parse CSV file: %v", err)
}
if len(targets) == 0 {
return nil, fmt.Errorf("CSV file is empty")
t.logger.Warningf("failed to parse CSV file %v: %v", t.filename, err)
return nil
}

return targets, nil
targetsByType := make(MountSpecPathsByType)
for _, t := range targets {
targetsByType[t.Type] = append(targetsByType[t.Type], t.Path)
}
return targetsByType
}
30 changes: 9 additions & 21 deletions internal/platform-support/tegra/csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"

"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
Expand All @@ -34,7 +33,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
description string
moutSpecs map[csv.MountSpecType][]string
moutSpecs MountSpecPathsByType
ignorePatterns []string
symlinkLocator lookup.Locator
symlinkChainLocator lookup.Locator
Expand Down Expand Up @@ -186,19 +185,19 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
hookCreator := discover.NewHookCreator()
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()

o := tegraOptions{
logger: logger,
hookCreator: hookCreator,
csvFiles: []string{"dummy"},
ignorePatterns: tc.ignorePatterns,
o := options{
logger: logger,
hookCreator: hookCreator,
MountSpecPathsByTyper: Filter(
tc.moutSpecs,
Symlinks(tc.ignorePatterns...),
),
symlinkLocator: tc.symlinkLocator,
symlinkChainLocator: tc.symlinkChainLocator,
resolveSymlink: tc.symlinkResolver,
}

d, err := o.newDiscovererFromCSVFiles()
d, err := o.newDiscovererFromMountSpecs()
require.ErrorIs(t, err, tc.expectedError)

hooks, err := d.Hooks()
Expand All @@ -212,14 +211,3 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
})
}
}

func setGetTargetsFromCSVFiles(override map[csv.MountSpecType][]string) func() {
original := getTargetsFromCSVFiles
getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
return override
}

return func() {
getTargetsFromCSVFiles = original
}
}
Loading