Skip to content

Commit 90c6cfb

Browse files
committed
Merge branch 'fix-cgroup-root' into 'master'
Fix bug where cgroup mount prefix not stripped from cgroup root See merge request nvidia/container-toolkit/libnvidia-container!139
2 parents 14b0453 + 162f9ba commit 90c6cfb

File tree

5 files changed

+61
-38
lines changed

5 files changed

+61
-38
lines changed

src/cgroup.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,10 @@ nvcgo_find_device_cgroup_path_1_svc(ptr_t ctxptr, int dev_cg_version, char *proc
108108
struct error *err = (struct error[]){0};
109109
struct nvcgo *nvcgo = (struct nvcgo *)ctxptr;
110110
char path[PATH_MAX] = {};
111-
char *cgroup_path = NULL;
111+
char *cgroup_mount_prefix = NULL;
112112
char *cgroup_mount = NULL;
113113
char *cgroup_root = NULL;
114+
char *cgroup_path = NULL;
114115
char *rerr = NULL;
115116
int rv = -1;
116117

@@ -127,12 +128,12 @@ nvcgo_find_device_cgroup_path_1_svc(ptr_t ctxptr, int dev_cg_version, char *proc
127128
if (perm_set_capabilities(err, CAP_EFFECTIVE, ecaps[NVC_CONTAINER], ecaps_size(NVC_CONTAINER)) < 0)
128129
goto fail;
129130

130-
if ((rv = nvcgo->api.GetDeviceCGroupMountPath(dev_cg_version, proc_root, mp_pid, &cgroup_mount, &rerr)) < 0) {
131+
if ((rv = nvcgo->api.GetDeviceCGroupMountPath(dev_cg_version, proc_root, mp_pid, &cgroup_mount_prefix, &cgroup_mount, &rerr)) < 0) {
131132
error_setx(err, "failed to get device cgroup mount path: %s", rerr);
132133
goto fail;
133134
}
134135

135-
if ((rv = nvcgo->api.GetDeviceCGroupRootPath(dev_cg_version, proc_root, rp_pid, &cgroup_root, &rerr)) < 0) {
136+
if ((rv = nvcgo->api.GetDeviceCGroupRootPath(dev_cg_version, proc_root, cgroup_mount_prefix, rp_pid, &cgroup_root, &rerr)) < 0) {
136137
error_setx(err, "failed to get device cgroup root path: %s", rerr);
137138
goto fail;
138139
}

src/nvcgo/internal/cgroup/api.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ import (
2929
type DeviceRule = specs.LinuxDeviceCgroup
3030

3131
type Interface interface {
32-
GetDeviceCGroupMountPath(procRootPath string, pid int) (string, error)
33-
GetDeviceCGroupRootPath(procRootPath string, pid int) (string, error)
32+
GetDeviceCGroupMountPath(procRootPath string, pid int) (string, string, error)
33+
GetDeviceCGroupRootPath(procRootPath string, prefix string, pid int) (string, error)
3434
AddDeviceRules(cgroupPath string, devices []DeviceRule) error
3535
}
3636

src/nvcgo/internal/cgroup/v1.go

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ import (
2424
"strings"
2525
)
2626

27-
// GetDeviceCGroupMountPath returns the mount path for the device cgroup controller associated with pid
28-
func (c *cgroupv1) GetDeviceCGroupMountPath(procRootPath string, pid int) (string, error) {
27+
// GetDeviceCGroupMountPath returns the mount path (and its prefix) for the device cgroup controller associated with pid
28+
func (c *cgroupv1) GetDeviceCGroupMountPath(procRootPath string, pid int) (string, string, error) {
2929
// Open the pid's mountinfo file in /proc.
3030
path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "mountinfo"), pid)
3131
file, err := os.Open(path)
3232
if err != nil {
33-
return "", err
33+
return "", "", err
3434
}
3535
defer file.Close()
3636

@@ -43,7 +43,7 @@ func (c *cgroupv1) GetDeviceCGroupMountPath(procRootPath string, pid int) (strin
4343
// Split each entry by '[space]'
4444
parts := strings.Split(scanner.Text(), " ")
4545
if len(parts) < 5 {
46-
return "", fmt.Errorf("malformed mountinfo entry: %v", scanner.Text())
46+
return "", "", fmt.Errorf("malformed mountinfo entry: %v", scanner.Text())
4747
}
4848
// Look for an entry with cgroup as the mount type.
4949
if parts[len(parts)-3] != "cgroup" {
@@ -53,15 +53,21 @@ func (c *cgroupv1) GetDeviceCGroupMountPath(procRootPath string, pid int) (strin
5353
if filepath.Base(parts[4]) != "devices" {
5454
continue
5555
}
56-
// Return the 4th element as the mount point of the devices cgroup.
57-
return parts[4], nil
56+
// Make sure the mount prefix is not a relative path.
57+
if strings.HasPrefix(parts[3], "/..") {
58+
return "", "", fmt.Errorf("relative path in mount prefix: %v", parts[3])
59+
}
60+
// Return the 3rd element as the prefix of the mount point for
61+
// the devices cgroup and the 4th element as the mount point of
62+
// the devices cgroup itself.
63+
return parts[3], parts[4], nil
5864
}
5965

60-
return "", fmt.Errorf("no cgroup filesystem mounted for the devices subsytem in mountinfo file")
66+
return "", "", fmt.Errorf("no cgroup filesystem mounted for the devices subsytem in mountinfo file")
6167
}
6268

63-
// GetDeviceCGroupMountPath returns the root path for the device cgroup controller associated with pid
64-
func (c *cgroupv1) GetDeviceCGroupRootPath(procRootPath string, pid int) (string, error) {
69+
// GetDeviceCGroupRootPath returns the root path for the device cgroup controller associated with pid
70+
func (c *cgroupv1) GetDeviceCGroupRootPath(procRootPath string, prefix string, pid int) (string, error) {
6571
// Open the pid's cgroup file in /proc.
6672
path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "cgroup"), pid)
6773
file, err := os.Open(path)
@@ -81,12 +87,16 @@ func (c *cgroupv1) GetDeviceCGroupRootPath(procRootPath string, pid int) (string
8187
if len(parts) != 3 {
8288
return "", fmt.Errorf("malformed cgroup entry: %v", scanner.Text())
8389
}
84-
// Look for the devices subsystem in the 2st element.
90+
// Look for the devices subsystem in the 1st element.
8591
if parts[1] != "devices" {
8692
continue
8793
}
88-
// Return the cgroup root from the 2nd element.
89-
return parts[2], nil
94+
// Return the cgroup root from the 2nd element
95+
// (with the prefix possibly stripped off).
96+
if prefix == "/" {
97+
return parts[2], nil
98+
}
99+
return strings.TrimPrefix(parts[2], prefix), nil
90100
}
91101

92102
return "", fmt.Errorf("no devices cgroup entries found")
@@ -96,17 +106,17 @@ func (c *cgroupv1) GetDeviceCGroupRootPath(procRootPath string, pid int) (string
96106
func (c *cgroupv1) AddDeviceRules(cgroupPath string, rules []DeviceRule) error {
97107
// Loop through all rules in the set of device rules and add that rule to the device.
98108
for _, rule := range rules {
99-
err := c.addDeviceRule(cgroupPath, &rule)
100-
if err != nil {
101-
return err
102-
}
109+
err := c.addDeviceRule(cgroupPath, &rule)
110+
if err != nil {
111+
return err
112+
}
103113
}
104114

105115
return nil
106116
}
107117

108118
func (c *cgroupv1) addDeviceRule(cgroupPath string, rule *DeviceRule) error {
109-
// Check the major/minor numbers of the device in the device rule.
119+
// Check the major/minor numbers of the device in the device rule.
110120
if rule.Major == nil {
111121
return fmt.Errorf("no major set in device rule")
112122
}
@@ -126,7 +136,7 @@ func (c *cgroupv1) addDeviceRule(cgroupPath string, rule *DeviceRule) error {
126136
if err != nil {
127137
return err
128138
}
129-
defer file.Close()
139+
defer file.Close()
130140

131141
// Write the device rule into the file.
132142
_, err = file.WriteString(fmt.Sprintf("%s %d:%d %s", rule.Type, *rule.Major, *rule.Minor, rule.Access))

src/nvcgo/internal/cgroup/v2.go

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ const (
3232
BpfProgramLicense = "Apache"
3333
)
3434

35-
// GetDeviceCGroupMountPath returns the mount path for the device cgroup controller associated with pid
36-
func (c *cgroupv2) GetDeviceCGroupMountPath(procRootPath string, pid int) (string, error) {
35+
// GetDeviceCGroupMountPath returns the mount path (and its prefix) for the device cgroup controller associated with pid
36+
func (c *cgroupv2) GetDeviceCGroupMountPath(procRootPath string, pid int) (string, string, error) {
3737
// Open the pid's mountinfo file in /proc.
3838
path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "mountinfo"), pid)
3939
file, err := os.Open(path)
4040
if err != nil {
41-
return "", err
41+
return "", "", err
4242
}
4343
defer file.Close()
4444

@@ -51,21 +51,27 @@ func (c *cgroupv2) GetDeviceCGroupMountPath(procRootPath string, pid int) (strin
5151
// Split each entry by '[space]'
5252
parts := strings.Split(scanner.Text(), " ")
5353
if len(parts) < 5 {
54-
return "", fmt.Errorf("malformed mountinfo entry: %v", scanner.Text())
54+
return "", "", fmt.Errorf("malformed mountinfo entry: %v", scanner.Text())
5555
}
5656
// Look for an entry with cgroup2 as the mount type.
5757
if parts[len(parts)-3] != "cgroup2" {
5858
continue
5959
}
60-
// Return the 4th element as the moint point of the devices cgroup.
61-
return parts[4], nil
60+
// Make sure the mount prefix is not a relative path.
61+
if strings.HasPrefix(parts[3], "/..") {
62+
return "", "", fmt.Errorf("relative path in mount prefix: %v", parts[3])
63+
}
64+
// Return the 3rd element as the prefix of the mount point for
65+
// the devices cgroup and the 4th element as the mount point of
66+
// the devices cgroup itself.
67+
return parts[3], parts[4], nil
6268
}
6369

64-
return "", fmt.Errorf("no cgroup2 filesystem in mountinfo file")
70+
return "", "", fmt.Errorf("no cgroup2 filesystem in mountinfo file")
6571
}
6672

67-
// GetDeviceCGroupMountPath returns the root path for the device cgroup controller associated with pid
68-
func (c *cgroupv2) GetDeviceCGroupRootPath(procRootPath string, pid int) (string, error) {
73+
// GetDeviceCGroupRootPath returns the root path for the device cgroup controller associated with pid
74+
func (c *cgroupv2) GetDeviceCGroupRootPath(procRootPath string, prefix string, pid int) (string, error) {
6975
// Open the pid's cgroup file in /proc.
7076
path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "cgroup"), pid)
7177
file, err := os.Open(path)
@@ -85,11 +91,16 @@ func (c *cgroupv2) GetDeviceCGroupRootPath(procRootPath string, pid int) (string
8591
if len(parts) != 3 {
8692
return "", fmt.Errorf("malformed cgroup entry: %v", scanner.Text())
8793
}
94+
// Look for the (empty) subsystem in the 1st element.
8895
if parts[1] != "" {
8996
continue
9097
}
91-
// Return the cgroup root from the 2nd element.
92-
return parts[2], nil
98+
// Return the cgroup root from the 2nd element
99+
// (with the prefix possibly stripped off).
100+
if prefix == "/" {
101+
return parts[2], nil
102+
}
103+
return strings.TrimPrefix(parts[2], prefix), nil
93104
}
94105

95106
return "", fmt.Errorf("no cgroupv2 entries in file")

src/nvcgo/main.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,32 +54,33 @@ func GetDeviceCGroupVersion(rootPath *C.char, pid C.pid_t, version *C.int, rerr
5454
}
5555

5656
//export GetDeviceCGroupMountPath
57-
func GetDeviceCGroupMountPath(version C.int, procRootPath *C.char, pid C.pid_t, cgroupMountPath **C.char, rerr **C.char) C.int {
57+
func GetDeviceCGroupMountPath(version C.int, procRootPath *C.char, pid C.pid_t, cgroupMountPath **C.char, cgroupRootPrefix **C.char, rerr **C.char) C.int {
5858
api, err := cgroup.New(int(version))
5959
if err != nil {
6060
*rerr = C.CString(fmt.Sprintf("unable to create cgroupv%v interface: %v", version, err))
6161
return -1
6262
}
6363

64-
p, err := api.GetDeviceCGroupMountPath(C.GoString(procRootPath), int(pid))
64+
p, r, err := api.GetDeviceCGroupMountPath(C.GoString(procRootPath), int(pid))
6565
if err != nil {
6666
*rerr = C.CString(err.Error())
6767
return -1
6868
}
6969
*cgroupMountPath = C.CString(p)
70+
*cgroupRootPrefix= C.CString(r)
7071

7172
return 0
7273
}
7374

7475
//export GetDeviceCGroupRootPath
75-
func GetDeviceCGroupRootPath(version C.int, procRootPath *C.char, pid C.int, cgroupRootPath **C.char, rerr **C.char) C.int {
76+
func GetDeviceCGroupRootPath(version C.int, procRootPath *C.char, cgroupRootPrefix *C.char, pid C.int, cgroupRootPath **C.char, rerr **C.char) C.int {
7677
api, err := cgroup.New(int(version))
7778
if err != nil {
7879
*rerr = C.CString(fmt.Sprintf("unable to create cgroupv%v interface: %v", version, err))
7980
return -1
8081
}
8182

82-
p, err := api.GetDeviceCGroupRootPath(C.GoString(procRootPath), int(pid))
83+
p, err := api.GetDeviceCGroupRootPath(C.GoString(procRootPath), C.GoString(cgroupRootPrefix), int(pid))
8384
if err != nil {
8485
*rerr = C.CString(err.Error())
8586
return -1

0 commit comments

Comments
 (0)