Skip to content

Commit 963250a

Browse files
committed
Refactor CSV discovery for testability
This change improves the testibility of the CSV discoverer. This is done by adding injection points for mocks for library discovery and symlink resolution. Note that this highlights a bug in the current implementation where the library filter causes valid symlinks to be skipped. Signed-off-by: Evan Lezar <[email protected]>
1 parent be570fc commit 963250a

File tree

4 files changed

+186
-40
lines changed

4 files changed

+186
-40
lines changed

internal/platform-support/tegra/csv.go

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,51 +28,45 @@ import (
2828
// newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied.
2929
// The constructed discoverer is comprised of a list, with each element in the list being associated with a
3030
// single CSV files.
31-
func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string, nvidiaCTKPath string, librarySearchPaths []string) (discover.Discover, error) {
32-
if len(files) == 0 {
33-
logger.Warningf("No CSV files specified")
31+
func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
32+
if len(o.csvFiles) == 0 {
33+
o.logger.Warningf("No CSV files specified")
3434
return discover.None{}, nil
3535
}
3636

37-
targetsByType := getTargetsFromCSVFiles(logger, files)
37+
targetsByType := getTargetsFromCSVFiles(o.logger, o.csvFiles)
3838

3939
devices := discover.NewDeviceDiscoverer(
40-
logger,
41-
lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)),
42-
driverRoot,
40+
o.logger,
41+
lookup.NewCharDeviceLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)),
42+
o.driverRoot,
4343
targetsByType[csv.MountSpecDev],
4444
)
4545

4646
directories := discover.NewMounts(
47-
logger,
48-
lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)),
49-
driverRoot,
47+
o.logger,
48+
lookup.NewDirectoryLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)),
49+
o.driverRoot,
5050
targetsByType[csv.MountSpecDir],
5151
)
5252

5353
// Libraries and symlinks use the same locator.
54-
searchPaths := append(librarySearchPaths, "/")
55-
symlinkLocator := lookup.NewSymlinkLocator(
56-
lookup.WithLogger(logger),
57-
lookup.WithRoot(driverRoot),
58-
lookup.WithSearchPaths(searchPaths...),
59-
)
6054
libraries := discover.NewMounts(
61-
logger,
62-
symlinkLocator,
63-
driverRoot,
55+
o.logger,
56+
o.symlinkLocator,
57+
o.driverRoot,
6458
targetsByType[csv.MountSpecLib],
6559
)
6660

6761
nonLibSymlinks := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply(targetsByType[csv.MountSpecSym]...)
68-
logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks)
62+
o.logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks)
6963
symlinks := discover.NewMounts(
70-
logger,
71-
symlinkLocator,
72-
driverRoot,
64+
o.logger,
65+
o.symlinkLocator,
66+
o.driverRoot,
7367
nonLibSymlinks,
7468
)
75-
createSymlinks := createCSVSymlinkHooks(logger, nonLibSymlinks, libraries, nvidiaCTKPath)
69+
createSymlinks := o.createCSVSymlinkHooks(nonLibSymlinks, libraries)
7670

7771
d := discover.Merge(
7872
devices,
@@ -87,7 +81,9 @@ func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRo
8781

8882
// getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files.
8983
// These are aggregated by mount spec type.
90-
func getTargetsFromCSVFiles(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
84+
// TODO: We use a function variable here to allow this to be overridden for testing.
85+
// This should be properly mocked.
86+
var getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
9187
targetsByType := make(map[csv.MountSpecType][]string)
9288
for _, filename := range files {
9389
targets, err := loadCSVFile(logger, filename)

internal/platform-support/tegra/csv_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,127 @@
1515
**/
1616

1717
package tegra
18+
19+
import (
20+
"fmt"
21+
"testing"
22+
23+
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
24+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
25+
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
26+
testlog "github.com/sirupsen/logrus/hooks/test"
27+
"github.com/stretchr/testify/require"
28+
29+
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
30+
)
31+
32+
func TestDiscovererFromCSVFiles(t *testing.T) {
33+
logger, _ := testlog.NewNullLogger()
34+
testCases := []struct {
35+
description string
36+
moutSpecs map[csv.MountSpecType][]string
37+
symlinkLocator lookup.Locator
38+
symlinkChainLocator lookup.Locator
39+
symlinkResolver func(string) (string, error)
40+
expectedError error
41+
expectedMounts []discover.Mount
42+
expectedMountsError error
43+
expectedHooks []discover.Hook
44+
expectedHooksError error
45+
}{
46+
{
47+
// TODO: This current resolves to two mounts that are the same.
48+
// These are deduplicated at a later stage. We could consider deduplicating earlier in the pipeline.
49+
description: "symlink is resolved to target; mounts and symlink are created",
50+
moutSpecs: map[csv.MountSpecType][]string{
51+
"lib": {"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"},
52+
"sym": {"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so"},
53+
},
54+
symlinkLocator: &lookup.LocatorMock{
55+
LocateFunc: func(path string) ([]string, error) {
56+
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
57+
return []string{"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil
58+
}
59+
return []string{path}, nil
60+
},
61+
},
62+
symlinkChainLocator: &lookup.LocatorMock{
63+
LocateFunc: func(path string) ([]string, error) {
64+
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
65+
return []string{"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so", "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil
66+
}
67+
return nil, fmt.Errorf("Unexpected path: %v", path)
68+
},
69+
},
70+
symlinkResolver: func(path string) (string, error) {
71+
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
72+
return "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", nil
73+
}
74+
return path, nil
75+
},
76+
expectedMounts: []discover.Mount{
77+
{
78+
Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
79+
HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
80+
Options: []string{"ro", "nosuid", "nodev", "bind"},
81+
},
82+
{
83+
Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
84+
HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
85+
Options: []string{"ro", "nosuid", "nodev", "bind"},
86+
},
87+
},
88+
expectedHooks: []discover.Hook{
89+
{
90+
Lifecycle: "createContainer",
91+
Path: "/usr/bin/nvidia-ctk",
92+
Args: []string{
93+
"nvidia-ctk",
94+
"hook",
95+
"create-symlinks",
96+
"--link",
97+
"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so::/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so",
98+
},
99+
},
100+
},
101+
},
102+
}
103+
104+
for _, tc := range testCases {
105+
t.Run(tc.description, func(t *testing.T) {
106+
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()
107+
108+
o := tegraOptions{
109+
logger: logger,
110+
nvidiaCTKPath: "/usr/bin/nvidia-ctk",
111+
csvFiles: []string{"dummy"},
112+
symlinkLocator: tc.symlinkLocator,
113+
symlinkChainLocator: tc.symlinkChainLocator,
114+
resolveSymlink: tc.symlinkResolver,
115+
}
116+
117+
d, err := o.newDiscovererFromCSVFiles()
118+
require.ErrorIs(t, err, tc.expectedError)
119+
120+
hooks, err := d.Hooks()
121+
require.ErrorIs(t, err, tc.expectedHooksError)
122+
require.EqualValues(t, tc.expectedHooks, hooks)
123+
124+
mounts, err := d.Mounts()
125+
require.ErrorIs(t, err, tc.expectedMountsError)
126+
require.EqualValues(t, tc.expectedMounts, mounts)
127+
128+
})
129+
}
130+
}
131+
132+
func setGetTargetsFromCSVFiles(ovverride map[csv.MountSpecType][]string) func() {
133+
original := getTargetsFromCSVFiles
134+
getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
135+
return ovverride
136+
}
137+
138+
return func() {
139+
getTargetsFromCSVFiles = original
140+
}
141+
}

internal/platform-support/tegra/symlinks.go

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,29 @@ import (
2424
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2525
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2626
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
27-
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
2827
)
2928

3029
type symlinkHook struct {
3130
discover.None
3231
logger logger.Interface
33-
driverRoot string
3432
nvidiaCTKPath string
3533
targets []string
3634
mountsFrom discover.Discover
35+
36+
// The following can be overridden for testing
37+
symlinkChainLocator lookup.Locator
38+
resolveSymlink func(string) (string, error)
3739
}
3840

3941
// createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container
40-
func createCSVSymlinkHooks(logger logger.Interface, targets []string, mounts discover.Discover, nvidiaCTKPath string) discover.Discover {
42+
func (o tegraOptions) createCSVSymlinkHooks(targets []string, mounts discover.Discover) discover.Discover {
4143
return symlinkHook{
42-
logger: logger,
43-
nvidiaCTKPath: nvidiaCTKPath,
44-
targets: targets,
45-
mountsFrom: mounts,
44+
logger: o.logger,
45+
nvidiaCTKPath: o.nvidiaCTKPath,
46+
targets: targets,
47+
mountsFrom: mounts,
48+
symlinkChainLocator: o.symlinkChainLocator,
49+
resolveSymlink: o.resolveSymlink,
4650
}
4751
}
4852

@@ -105,14 +109,9 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) {
105109

106110
// getSymlinkCandidates returns a list of symlinks that are candidates for being created.
107111
func (d symlinkHook) getSymlinkCandidates() []string {
108-
chainLocator := lookup.NewSymlinkChainLocator(
109-
lookup.WithLogger(d.logger),
110-
lookup.WithRoot(d.driverRoot),
111-
)
112-
113112
var candidates []string
114113
for _, target := range d.targets {
115-
reslovedSymlinkChain, err := chainLocator.Locate(target)
114+
reslovedSymlinkChain, err := d.symlinkChainLocator.Locate(target)
116115
if err != nil {
117116
d.logger.Warningf("Failed to locate symlink %v", target)
118117
continue
@@ -127,7 +126,7 @@ func (d symlinkHook) getCSVFileSymlinks() []string {
127126
created := make(map[string]bool)
128127
// candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain.
129128
for _, candidate := range d.getSymlinkCandidates() {
130-
target, err := symlinks.Resolve(candidate)
129+
target, err := d.resolveSymlink(candidate)
131130
if err != nil {
132131
d.logger.Debugf("Skipping invalid link: %v", err)
133132
continue

internal/platform-support/tegra/tegra.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2323
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2424
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
25+
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
2526
)
2627

2728
type tegraOptions struct {
@@ -30,6 +31,12 @@ type tegraOptions struct {
3031
driverRoot string
3132
nvidiaCTKPath string
3233
librarySearchPaths []string
34+
35+
// The following can be overridden for testing
36+
symlinkLocator lookup.Locator
37+
symlinkChainLocator lookup.Locator
38+
// TODO: This should be replaced by a regular mock
39+
resolveSymlink func(string) (string, error)
3340
}
3441

3542
// Option defines a functional option for configuring a Tegra discoverer.
@@ -42,7 +49,27 @@ func New(opts ...Option) (discover.Discover, error) {
4249
opt(o)
4350
}
4451

45-
csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot, o.nvidiaCTKPath, o.librarySearchPaths)
52+
if o.symlinkLocator == nil {
53+
searchPaths := append(o.librarySearchPaths, "/")
54+
o.symlinkLocator = lookup.NewSymlinkLocator(
55+
lookup.WithLogger(o.logger),
56+
lookup.WithRoot(o.driverRoot),
57+
lookup.WithSearchPaths(searchPaths...),
58+
)
59+
}
60+
61+
if o.symlinkChainLocator == nil {
62+
o.symlinkChainLocator = lookup.NewSymlinkChainLocator(
63+
lookup.WithLogger(o.logger),
64+
lookup.WithRoot(o.driverRoot),
65+
)
66+
}
67+
68+
if o.resolveSymlink == nil {
69+
o.resolveSymlink = symlinks.Resolve
70+
}
71+
72+
csvDiscoverer, err := o.newDiscovererFromCSVFiles()
4673
if err != nil {
4774
return nil, fmt.Errorf("failed to create CSV discoverer: %v", err)
4875
}

0 commit comments

Comments
 (0)