Skip to content

Commit 882fbb3

Browse files
author
Evan Lezar
committed
Merge branch 'add-cdi-auto-mode' into 'main'
Add constants for CDI mode to nvcdi API See merge request nvidia/container-toolkit/container-toolkit!302
2 parents ba50b50 + 2680c45 commit 882fbb3

File tree

4 files changed

+114
-18
lines changed

4 files changed

+114
-18
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ import (
3636
)
3737

3838
const (
39-
discoveryModeAuto = "auto"
40-
discoveryModeNVML = "nvml"
41-
discoveryModeWSL = "wsl"
42-
4339
formatJSON = "json"
4440
formatYAML = "yaml"
4541

@@ -97,8 +93,8 @@ func (m command) build() *cli.Command {
9793
},
9894
&cli.StringFlag{
9995
Name: "discovery-mode",
100-
Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. I mode is set to 'auto' the mode will be determined based on the system configuration.",
101-
Value: discoveryModeAuto,
96+
Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.",
97+
Value: nvcdi.ModeAuto,
10298
Destination: &cfg.discoveryMode,
10399
},
104100
&cli.StringFlag{
@@ -133,9 +129,9 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error {
133129

134130
cfg.discoveryMode = strings.ToLower(cfg.discoveryMode)
135131
switch cfg.discoveryMode {
136-
case discoveryModeAuto:
137-
case discoveryModeNVML:
138-
case discoveryModeWSL:
132+
case nvcdi.ModeAuto:
133+
case nvcdi.ModeNvml:
134+
case nvcdi.ModeWsl:
139135
default:
140136
return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode)
141137
}

pkg/nvcdi/api.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ import (
2222
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
2323
)
2424

25+
const (
26+
// ModeAuto configures the CDI spec generator to automatically detect the system configuration
27+
ModeAuto = "auto"
28+
// ModeNvml configures the CDI spec generator to use the NVML library.
29+
ModeNvml = "nvml"
30+
// ModeWsl configures the CDI spec generator to generate a WSL spec.
31+
ModeWsl = "wsl"
32+
)
33+
2534
// Interface defines the API for the nvcdi package
2635
type Interface interface {
2736
GetCommonEdits() (*cdi.ContainerEdits, error)

pkg/nvcdi/lib.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ type nvcdilib struct {
3131
deviceNamer DeviceNamer
3232
driverRoot string
3333
nvidiaCTKPath string
34+
35+
infolib info.Interface
3436
}
3537

3638
// New creates a new nvcdi library
@@ -40,7 +42,7 @@ func New(opts ...Option) Interface {
4042
opt(l)
4143
}
4244
if l.mode == "" {
43-
l.mode = "auto"
45+
l.mode = ModeAuto
4446
}
4547
if l.logger == nil {
4648
l.logger = logrus.StandardLogger()
@@ -54,9 +56,12 @@ func New(opts ...Option) Interface {
5456
if l.nvidiaCTKPath == "" {
5557
l.nvidiaCTKPath = "/usr/bin/nvidia-ctk"
5658
}
59+
if l.infolib == nil {
60+
l.infolib = info.New()
61+
}
5762

5863
switch l.resolveMode() {
59-
case "nvml":
64+
case ModeNvml:
6065
if l.nvmllib == nil {
6166
l.nvmllib = nvml.New()
6267
}
@@ -65,7 +70,7 @@ func New(opts ...Option) Interface {
6570
}
6671

6772
return (*nvmllib)(l)
68-
case "wsl":
73+
case ModeWsl:
6974
return (*wsllib)(l)
7075
}
7176

@@ -75,21 +80,19 @@ func New(opts ...Option) Interface {
7580

7681
// resolveMode resolves the mode for CDI spec generation based on the current system.
7782
func (l *nvcdilib) resolveMode() (rmode string) {
78-
if l.mode != "auto" {
83+
if l.mode != ModeAuto {
7984
return l.mode
8085
}
8186
defer func() {
8287
l.logger.Infof("Auto-detected mode as %q", rmode)
8388
}()
8489

85-
nvinfo := info.New()
86-
87-
isWSL, reason := nvinfo.HasDXCore()
90+
isWSL, reason := l.infolib.HasDXCore()
8891
l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason)
8992

9093
if isWSL {
91-
return "wsl"
94+
return ModeWsl
9295
}
9396

94-
return "nvml"
97+
return ModeNvml
9598
}

pkg/nvcdi/lib_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/**
2+
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvcdi
18+
19+
import (
20+
"fmt"
21+
"testing"
22+
23+
testlog "github.com/sirupsen/logrus/hooks/test"
24+
"github.com/stretchr/testify/require"
25+
)
26+
27+
func TestResolveMode(t *testing.T) {
28+
logger, _ := testlog.NewNullLogger()
29+
30+
testCases := []struct {
31+
mode string
32+
// TODO: This should be a proper mock
33+
hasDXCore bool
34+
expected string
35+
}{
36+
{
37+
mode: "auto",
38+
hasDXCore: true,
39+
expected: "wsl",
40+
},
41+
{
42+
mode: "auto",
43+
hasDXCore: false,
44+
expected: "nvml",
45+
},
46+
{
47+
mode: "nvml",
48+
hasDXCore: true,
49+
expected: "nvml",
50+
},
51+
{
52+
mode: "wsl",
53+
hasDXCore: false,
54+
expected: "wsl",
55+
},
56+
{
57+
mode: "not-auto",
58+
hasDXCore: true,
59+
expected: "not-auto",
60+
},
61+
}
62+
63+
for i, tc := range testCases {
64+
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
65+
l := nvcdilib{
66+
logger: logger,
67+
mode: tc.mode,
68+
infolib: infoMock(tc.hasDXCore),
69+
}
70+
71+
require.Equal(t, tc.expected, l.resolveMode())
72+
})
73+
}
74+
}
75+
76+
type infoMock bool
77+
78+
func (i infoMock) HasDXCore() (bool, string) {
79+
return bool(i), ""
80+
}
81+
82+
func (i infoMock) HasNvml() (bool, string) {
83+
panic("should not be called")
84+
}
85+
86+
func (i infoMock) IsTegraSystem() (bool, string) {
87+
panic("should not be called")
88+
}

0 commit comments

Comments
 (0)