Skip to content

Commit 791cd7f

Browse files
committed
Add a hook to update nvidia params
Signed-off-by: Evan Lezar <[email protected]>
1 parent 8713a17 commit 791cd7f

File tree

4 files changed

+218
-0
lines changed

4 files changed

+218
-0
lines changed

cmd/nvidia-cdi-hook/commands/commands.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
2323
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
2424
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
25+
nvidiaparams "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-nvidia-params"
2526
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2627
)
2728

@@ -32,5 +33,6 @@ func New(logger logger.Interface) []*cli.Command {
3233
ldcache.NewCommand(logger),
3334
symlinks.NewCommand(logger),
3435
chmod.NewCommand(logger),
36+
nvidiaparams.NewCommand(logger),
3537
}
3638
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//go:build linux
2+
// +build linux
3+
4+
/**
5+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
**/
19+
20+
package nvidiaparams
21+
22+
import (
23+
"golang.org/x/sys/unix"
24+
)
25+
26+
func bindMountReadonly(source string, target string) error {
27+
return unix.Mount(source, target, "", unix.MS_BIND|unix.MS_RDONLY|unix.MS_NOSYMFOLLOW, "")
28+
29+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//go:build !linux
2+
// +build !linux
3+
4+
/**
5+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
**/
19+
20+
package nvidiaparams
21+
22+
import (
23+
"fmt"
24+
)
25+
26+
func bindMountReadonly(source string, target string) error {
27+
return fmt.Errorf("not supported")
28+
}
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvidiaparams
18+
19+
import (
20+
"bufio"
21+
"errors"
22+
"fmt"
23+
"io"
24+
"os"
25+
"path/filepath"
26+
"strings"
27+
28+
"github.com/urfave/cli/v2"
29+
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
31+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
32+
)
33+
34+
const (
35+
nvidiaDriverParamsPath = "/proc/driver/nvidia/params"
36+
)
37+
38+
type command struct {
39+
logger logger.Interface
40+
}
41+
42+
type options struct {
43+
containerSpec string
44+
}
45+
46+
// NewCommand constructs an update-nvidia-params command with the specified logger
47+
func NewCommand(logger logger.Interface) *cli.Command {
48+
c := command{
49+
logger: logger,
50+
}
51+
return c.build()
52+
}
53+
54+
// build the update-nvidia-params command
55+
func (m command) build() *cli.Command {
56+
cfg := options{}
57+
58+
// Create the 'update-nvidia-params' command
59+
c := cli.Command{
60+
Name: "update-nvidia-params",
61+
Usage: "Update ldcache in a container by running ldconfig",
62+
Before: func(c *cli.Context) error {
63+
return m.validateFlags(c, &cfg)
64+
},
65+
Action: func(c *cli.Context) error {
66+
return m.run(c, &cfg)
67+
},
68+
}
69+
70+
c.Flags = []cli.Flag{
71+
&cli.StringFlag{
72+
Name: "container-spec",
73+
Hidden: true,
74+
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
75+
Destination: &cfg.containerSpec,
76+
},
77+
}
78+
79+
return &c
80+
}
81+
82+
func (m command) validateFlags(c *cli.Context, cfg *options) error {
83+
return nil
84+
}
85+
86+
func (m command) run(c *cli.Context, cfg *options) error {
87+
s, err := oci.LoadContainerState(cfg.containerSpec)
88+
if err != nil {
89+
return fmt.Errorf("failed to load container state: %v", err)
90+
}
91+
92+
containerRoot, err := s.GetContainerRoot()
93+
if err != nil {
94+
return fmt.Errorf("failed to determined container root: %v", err)
95+
}
96+
97+
return m.updateNvidiaParams(containerRoot)
98+
}
99+
100+
func (m command) updateNvidiaParams(containerRoot string) error {
101+
// TODO: Do we need to prefix the driver root?
102+
currentParamsFile, err := os.Open(nvidiaDriverParamsPath)
103+
if errors.Is(err, os.ErrNotExist) {
104+
return nil
105+
}
106+
if err != nil {
107+
return fmt.Errorf("failed to load params file: %w", err)
108+
}
109+
defer currentParamsFile.Close()
110+
111+
return m.updateNvidiaParamsFromReader(currentParamsFile, containerRoot)
112+
}
113+
114+
func (m command) updateNvidiaParamsFromReader(r io.Reader, containerRoot string) error {
115+
var newLines []string
116+
scanner := bufio.NewScanner(r)
117+
var requiresModification bool
118+
for scanner.Scan() {
119+
line := scanner.Text()
120+
if strings.HasPrefix(line, "ModifyDeviceFiles: ") {
121+
if line == "ModifyDeviceFiles: 0" {
122+
m.logger.Debugf("Device node modification is already disabled; exiting")
123+
return nil
124+
}
125+
if line == "ModifyDeviceFiles: 1" {
126+
line = "ModifyDeviceFiles: 0"
127+
requiresModification = true
128+
}
129+
}
130+
newLines = append(newLines, line)
131+
}
132+
if err := scanner.Err(); err != nil {
133+
return fmt.Errorf("failed to read params file: %w", err)
134+
}
135+
136+
if !requiresModification {
137+
return nil
138+
}
139+
140+
containerParamsFile, err := os.CreateTemp("", "nvct-params-*")
141+
if err != nil {
142+
return fmt.Errorf("failed to create temporary params file: %w", err)
143+
}
144+
defer containerParamsFile.Close()
145+
146+
if _, err := containerParamsFile.WriteString(strings.Join(newLines, "\n")); err != nil {
147+
return fmt.Errorf("failed to write temporary params file: %w", err)
148+
}
149+
150+
if err := containerParamsFile.Chmod(0o644); err != nil {
151+
return fmt.Errorf("failed to set permissions on temporary params file: %w", err)
152+
}
153+
154+
if err := bindMountReadonly(containerParamsFile.Name(), filepath.Join(containerRoot, nvidiaDriverParamsPath)); err != nil {
155+
return fmt.Errorf("failed to create temporary parms file mount: %w", err)
156+
}
157+
158+
return nil
159+
}

0 commit comments

Comments
 (0)