Skip to content

Commit b23e42d

Browse files
committed
Add compat-libs hook to ensure that compat libs are used
Signed-off-by: Evan Lezar <[email protected]>
1 parent 1d0777e commit b23e42d

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package compatlibs
18+
19+
import (
20+
"fmt"
21+
"os"
22+
"path/filepath"
23+
"strings"
24+
25+
"github.com/urfave/cli/v2"
26+
27+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
28+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
29+
)
30+
31+
type command struct {
32+
logger logger.Interface
33+
}
34+
35+
type options struct {
36+
driverVersion string
37+
containerSpec string
38+
}
39+
40+
// NewCommand constructs an compat-libs command with the specified logger
41+
func NewCommand(logger logger.Interface) *cli.Command {
42+
c := command{
43+
logger: logger,
44+
}
45+
return c.build()
46+
}
47+
48+
// build the compat-libs command
49+
func (m command) build() *cli.Command {
50+
cfg := options{}
51+
52+
// Create the 'compat-libs' command
53+
c := cli.Command{
54+
Name: "compat-libs",
55+
Before: func(c *cli.Context) error {
56+
return m.validateFlags(c, &cfg)
57+
},
58+
Action: func(c *cli.Context) error {
59+
return m.run(c, &cfg)
60+
},
61+
}
62+
63+
c.Flags = []cli.Flag{
64+
&cli.StringFlag{
65+
Name: "driver-version",
66+
Usage: "Specify the host driver version",
67+
Destination: &cfg.driverVersion,
68+
},
69+
&cli.StringFlag{
70+
Name: "container-spec",
71+
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
72+
Destination: &cfg.containerSpec,
73+
},
74+
}
75+
76+
return &c
77+
}
78+
79+
func (m command) validateFlags(c *cli.Context, cfg *options) error {
80+
return nil
81+
}
82+
83+
func (m command) run(c *cli.Context, cfg *options) error {
84+
s, err := oci.LoadContainerState(cfg.containerSpec)
85+
if err != nil {
86+
return fmt.Errorf("failed to load container state: %v", err)
87+
}
88+
89+
containerRoot, err := s.GetContainerRoot()
90+
if err != nil {
91+
return fmt.Errorf("failed to determined container root: %v", err)
92+
}
93+
94+
if !root(containerRoot).hasPath("/usr/local/cuda/compat") {
95+
return nil
96+
}
97+
98+
if !root(containerRoot).hasPath("/etc/ld.so.cache") {
99+
// If there is no ldcache in the container, the hook is a no-op
100+
return nil
101+
}
102+
if !root(containerRoot).hasPath("/etc/ld.so.conf.d") {
103+
return nil
104+
}
105+
106+
libs, err := root(containerRoot).glob("/usr/local/cuda/compat/libcuda.so.*.*")
107+
if err != nil {
108+
m.logger.Warningf("Failed to find CUDA compat library: %v", err)
109+
return nil
110+
}
111+
112+
if len(libs) == 0 {
113+
return nil
114+
}
115+
116+
if len(libs) != 1 {
117+
m.logger.Warningf("Unexpected number of CUDA compat libraries: %v", libs)
118+
return nil
119+
}
120+
121+
compatVersion := strings.TrimPrefix(filepath.Base(libs[0]), "libcuda.so.")
122+
compatMajor := strings.SplitN(compatVersion, ".", 2)[0]
123+
driverMajor := strings.SplitN(cfg.driverVersion, ".", 2)[0]
124+
125+
if driverMajor > compatMajor {
126+
return nil
127+
}
128+
129+
return m.createConfig(containerRoot, []string{"/usr/local/cuda/compat"})
130+
}
131+
132+
type root string
133+
134+
func (r root) hasPath(path string) bool {
135+
_, err := os.Stat(filepath.Join(string(r), path))
136+
if err != nil && os.IsNotExist(err) {
137+
return false
138+
}
139+
return true
140+
}
141+
142+
func (r root) glob(pattern string) ([]string, error) {
143+
return filepath.Glob(filepath.Join(string(r), pattern))
144+
}
145+
146+
// createConfig creates (or updates) /etc/ld.so.conf.d/00-compat-<RANDOM_STRING>.conf in the container
147+
// to include the required paths.
148+
// Note that the 00-nvcr prefix is chosen to ensure that these libraries have
149+
// a higher precedence than other libraries on the system but are applied AFTER
150+
// 00-cuda-compat.conf.
151+
func (m command) createConfig(root string, folders []string) error {
152+
if len(folders) == 0 {
153+
m.logger.Debugf("No folders to add to /etc/ld.so.conf")
154+
return nil
155+
}
156+
157+
if err := os.MkdirAll(filepath.Join(root, "/etc/ld.so.conf.d"), 0755); err != nil {
158+
return fmt.Errorf("failed to create ld.so.conf.d: %v", err)
159+
}
160+
161+
configFile, err := os.CreateTemp(filepath.Join(root, "/etc/ld.so.conf.d"), "00-compat-*.conf")
162+
if err != nil {
163+
return fmt.Errorf("failed to create config file: %v", err)
164+
}
165+
defer configFile.Close()
166+
167+
m.logger.Debugf("Adding folders %v to %v", folders, configFile.Name())
168+
169+
configured := make(map[string]bool)
170+
for _, folder := range folders {
171+
if configured[folder] {
172+
continue
173+
}
174+
_, err = configFile.WriteString(fmt.Sprintf("%s\n", folder))
175+
if err != nil {
176+
return fmt.Errorf("failed to update ld.so.conf.d: %v", err)
177+
}
178+
configured[folder] = true
179+
}
180+
181+
// The created file needs to be world readable for the cases where the container is run as a non-root user.
182+
if err := os.Chmod(configFile.Name(), 0644); err != nil {
183+
return fmt.Errorf("failed to chmod config file: %v", err)
184+
}
185+
186+
return nil
187+
}

0 commit comments

Comments
 (0)