Skip to content

Commit 5496a0f

Browse files
committed
Evan's comments
Signed-off-by: Evan Lezar <[email protected]>
1 parent f3128d1 commit 5496a0f

File tree

10 files changed

+92
-139
lines changed

10 files changed

+92
-139
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,17 +258,15 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
258258

259259
privileged := isPrivileged(s)
260260

261-
opts := []image.Option{
261+
i, err := image.New(
262262
image.WithEnv(s.Process.Env),
263263
image.WithMounts(s.Mounts),
264264
image.WithPrivileged(privileged),
265265
image.WithDisableRequire(hookConfig.DisableRequire),
266266
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
267267
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
268268
image.WithAdditionalVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...),
269-
}
270-
271-
i, err := image.New(opts...)
269+
)
272270
if err != nil {
273271
log.Panicln(err)
274272
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -477,18 +477,11 @@ func TestGetNvidiaConfig(t *testing.T) {
477477
}
478478
for _, tc := range tests {
479479
t.Run(tc.description, func(t *testing.T) {
480-
opts := []image.Option{
480+
image, _ := image.New(
481481
image.WithEnvMap(tc.env),
482482
image.WithPrivileged(tc.privileged),
483-
image.WithAcceptEnvvarUnprivileged(true),
484-
}
485-
486-
if tc.hookConfig != nil {
487-
if tc.hookConfig.SwarmResource != "" {
488-
opts = append(opts, image.WithAdditionalVisibleDevicesEnvVars(tc.hookConfig.SwarmResource))
489-
}
490-
}
491-
image, _ := image.New(opts...)
483+
image.WithAdditionalVisibleDevicesEnvVars(tc.hookConfig.getSwarmResourceEnvvars()...),
484+
)
492485

493486
// Wrap the call to getNvidiaConfig() in a closure.
494487
var cfg *nvidiaConfig
@@ -852,15 +845,13 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
852845

853846
for _, tc := range tests {
854847
t.Run(tc.description, func(t *testing.T) {
855-
opts := []image.Option{
848+
image, _ := image.New(
856849
image.WithEnvMap(tc.env),
857850
image.WithPrivileged(true),
858851
image.WithAcceptDeviceListAsVolumeMounts(false),
859852
image.WithAcceptEnvvarUnprivileged(false),
860853
image.WithAdditionalVisibleDevicesEnvVars(tc.swarmResourceEnvvars...),
861-
}
862-
863-
image, _ := image.New(opts...)
854+
)
864855

865856
devices := image.VisibleDevices()
866857
require.EqualValues(t, tc.expectedDevices, devices)

cmd/nvidia-container-runtime-hook/hook_config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (c hookConfig) getConfigOption(fieldName string) string {
8888

8989
// getSwarmResourceEnvvars returns the swarm resource envvars for the config.
9090
func (c *hookConfig) getSwarmResourceEnvvars() []string {
91-
if c.SwarmResource == "" {
91+
if c == nil || c.SwarmResource == "" {
9292
return nil
9393
}
9494

internal/config/image/builder.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"strings"
2222

2323
"github.com/opencontainers/runtime-spec/specs-go"
24+
25+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2426
)
2527

2628
type builder struct {
@@ -34,12 +36,21 @@ type Option func(*builder) error
3436

3537
// New creates a new CUDA image from the input options.
3638
func New(opt ...Option) (CUDA, error) {
37-
b := &builder{}
39+
b := &builder{
40+
CUDA: CUDA{
41+
acceptEnvvarUnprivileged: true,
42+
},
43+
}
3844
for _, o := range opt {
3945
if err := o(b); err != nil {
4046
return CUDA{}, err
4147
}
4248
}
49+
50+
if b.logger == nil {
51+
b.logger = logger.New()
52+
}
53+
4354
if b.env == nil {
4455
b.env = make(map[string]string)
4556
}
@@ -103,6 +114,14 @@ func WithEnvMap(env map[string]string) Option {
103114
}
104115
}
105116

117+
// WithLogger sets the logger to use when creating the CUDA image.
118+
func WithLogger(logger logger.Interface) Option {
119+
return func(b *builder) error {
120+
b.logger = logger
121+
return nil
122+
}
123+
}
124+
106125
// WithMounts sets the mounts associated with the CUDA image.
107126
func WithMounts(mounts []specs.Mount) Option {
108127
return func(b *builder) error {
@@ -126,16 +145,7 @@ func WithPrivileged(isPrivileged bool) Option {
126145
// - Concatenate all of this
127146
func WithAdditionalVisibleDevicesEnvVars(visibleDevicesEnvVars ...string) Option {
128147
return func(b *builder) error {
129-
var result []string
130-
for _, v := range visibleDevicesEnvVars {
131-
for _, c := range strings.Split(v, ",") {
132-
trimmed := strings.TrimSpace(c)
133-
if trimmed != "" {
134-
result = append(result, trimmed)
135-
}
136-
}
137-
}
138-
b.additionalVisibleDevicesEnvVars = result
148+
b.additionalVisibleDevicesEnvVars = visibleDevicesEnvVars
139149
return nil
140150
}
141151
}

internal/config/image/builder_test.go

Lines changed: 0 additions & 69 deletions
This file was deleted.

internal/config/image/cuda_image.go

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"github.com/opencontainers/runtime-spec/specs-go"
2626
"golang.org/x/mod/semver"
2727
"tags.cncf.io/container-device-interface/pkg/parser"
28+
29+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2830
)
2931

3032
const (
@@ -38,30 +40,33 @@ const (
3840
// a map of environment variable to values that can be used to perform lookups
3941
// such as requirements.
4042
type CUDA struct {
41-
additionalVisibleDevicesEnvVars []string
42-
43-
env map[string]string
43+
logger logger.Interface
4444

45+
env map[string]string
4546
mounts []specs.Mount
4647

47-
acceptDeviceListAsVolumeMounts bool
48-
acceptEnvvarUnprivileged bool
49-
isPrivileged bool
48+
isPrivileged bool
49+
50+
additionalVisibleDevicesEnvVars []string
51+
acceptDeviceListAsVolumeMounts bool
52+
acceptEnvvarUnprivileged bool
5053
}
5154

5255
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
5356
// The process environment is read (if present) to construc the CUDA Image.
54-
func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) {
57+
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
5558
var env []string
5659
if spec != nil && spec.Process != nil {
5760
env = spec.Process.Env
5861
}
5962

60-
return New(
63+
specOpts := []Option{
6164
WithEnv(env),
6265
WithMounts(spec.Mounts),
6366
WithPrivileged(IsPrivileged((*OCISpec)(spec))),
64-
)
67+
}
68+
69+
return New(append(opts, specOpts...)...)
6570
}
6671

6772
// newCUDAImageFromEnv creates a CUDA image from the input environment. The environment
@@ -163,7 +168,7 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
163168

164169
// GetDriverCapabilities returns the requested driver capabilities.
165170
func (i CUDA) GetDriverCapabilities() DriverCapabilities {
166-
env := i.Getenv(EnvVarNvidiaDriverCapabilities)
171+
env := i.env[EnvVarNvidiaDriverCapabilities]
167172

168173
capabilities := make(DriverCapabilities)
169174
for _, c := range strings.Split(env, ",") {
@@ -174,7 +179,7 @@ func (i CUDA) GetDriverCapabilities() DriverCapabilities {
174179
}
175180

176181
func (i CUDA) legacyVersion() (string, error) {
177-
cudaVersion := i.Getenv(EnvVarCudaVersion)
182+
cudaVersion := i.env[EnvVarCudaVersion]
178183
majorMinor, err := parseMajorMinorVersion(cudaVersion)
179184
if err != nil {
180185
return "", fmt.Errorf("invalid CUDA version %v: %v", cudaVersion, err)
@@ -224,9 +229,42 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
224229
return hasCDIdevice
225230
}
226231

227-
// VisibleDevicesFromEnvVar returns the set of visible devices requested through
228-
// the NVIDIA_VISIBLE_DEVICES environment variable or any variables specified
229-
// in visibleDevicesEnvVars.
232+
// VisibleDevices returns a list of devices requested in the container image.
233+
// If volume mount requests are enabled these are returned if requested,
234+
// otherwise device requests through environment variables are considered.
235+
// In cases where environment variable requests required privileged containers,
236+
// such devices requests are ignored.
237+
func (i CUDA) VisibleDevices() []string {
238+
// If enabled, try and get the device list from volume mounts first
239+
if i.acceptDeviceListAsVolumeMounts {
240+
volumeMountDeviceRequests := i.visibleDevicesFromMounts()
241+
if len(volumeMountDeviceRequests) > 0 {
242+
return volumeMountDeviceRequests
243+
}
244+
}
245+
246+
// Get the Fallback to reading from the environment variable if privileges are correct
247+
envVarDeviceRequests := i.VisibleDevicesFromEnvVar()
248+
if len(envVarDeviceRequests) == 0 {
249+
return nil
250+
}
251+
252+
// If the container is privileged, or environment variable requests are
253+
// allowed for unprivileged containers, these devices are returned.
254+
if i.isPrivileged || i.acceptEnvvarUnprivileged {
255+
return envVarDeviceRequests
256+
}
257+
258+
// We log a warning if we are ignoring the environment variable requests.
259+
i.logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES in unprivileged container")
260+
261+
return nil
262+
}
263+
264+
// VisibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
265+
// If any of the preferredVisibleDeviceEnvVars are present in the image, they
266+
// are used to determine the visible devices. If this is not the caes, the
267+
// NVIDIA_VISIBLE_DEVICES environment variable is used.
230268
func (i CUDA) VisibleDevicesFromEnvVar() []string {
231269
for _, envVar := range i.additionalVisibleDevicesEnvVars {
232270
if i.HasEnvvar(envVar) {
@@ -284,28 +322,6 @@ func (i CUDA) DevicesFromMounts() []string {
284322
return devices
285323
}
286324

287-
func (i CUDA) VisibleDevices() []string {
288-
// If enabled, try and get the device list from volume mounts first
289-
if i.acceptDeviceListAsVolumeMounts {
290-
devices := i.visibleDevicesFromMounts()
291-
if len(devices) > 0 {
292-
return devices
293-
}
294-
}
295-
296-
// Fallback to reading from the environment variable if privileges are correct
297-
devices := i.VisibleDevicesFromEnvVar()
298-
if len(devices) == 0 {
299-
return nil
300-
}
301-
302-
if i.isPrivileged || i.acceptEnvvarUnprivileged {
303-
return devices
304-
}
305-
306-
return nil
307-
}
308-
309325
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
310326
func (i CUDA) CDIDevicesFromMounts() []string {
311327
var devices []string

internal/config/image/privileged.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
package image
1818

19-
import "github.com/opencontainers/runtime-spec/specs-go"
19+
import (
20+
"github.com/opencontainers/runtime-spec/specs-go"
21+
)
2022

2123
const (
2224
capSysAdmin = "CAP_SYS_ADMIN"

internal/modifier/cdi.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
7979
return annotationDevices, nil
8080
}
8181

82-
container, err := image.NewCUDAImageFromSpec(rawSpec)
82+
container, err := image.NewCUDAImageFromSpec(
83+
rawSpec,
84+
image.WithLogger(logger),
85+
)
8386
if err != nil {
8487
return nil, err
8588
}

internal/modifier/graphics_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ func TestGraphicsModifier(t *testing.T) {
9292
t.Run(tc.description, func(t *testing.T) {
9393
image, _ := image.New(
9494
image.WithEnvMap(tc.envmap),
95-
image.WithAcceptEnvvarUnprivileged(true),
9695
)
9796
required, _ := requiresGraphicsModifier(image)
9897
require.EqualValues(t, tc.expectedRequired, required)

internal/runtime/runtime_factory.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
7070
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
7171
}
7272

73-
image, err := image.NewCUDAImageFromSpec(rawSpec)
73+
image, err := image.NewCUDAImageFromSpec(
74+
rawSpec,
75+
image.WithLogger(logger),
76+
)
7477
if err != nil {
7578
return nil, err
7679
}

0 commit comments

Comments
 (0)