Skip to content

Commit d633684

Browse files
committed
add go-based script to auto-select kernel module type
Signed-off-by: Tariq Ibrahim <[email protected]>
1 parent b119e94 commit d633684

File tree

3 files changed

+314
-0
lines changed

3 files changed

+314
-0
lines changed

gpu-driver-util/go.mod

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module gpu-driver-util
2+
3+
go 1.23.0
4+
5+
require (
6+
github.com/sirupsen/logrus v1.9.3
7+
github.com/urfave/cli/v2 v2.27.5
8+
)
9+
10+
require (
11+
github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect
12+
github.com/russross/blackfriday/v2 v2.1.0 // indirect
13+
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
14+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
15+
)

gpu-driver-util/go.sum

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
github.com/cpuguy83/go-md2man/v2 v2.0.5 h1:ZtcqGrnekaHpVLArFSe4HK5DoKx1T0rq2DwVB0alcyc=
2+
github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
3+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
5+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
7+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
8+
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
9+
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
10+
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
11+
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
12+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
13+
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
14+
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
15+
github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w=
16+
github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ=
17+
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
18+
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
19+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
20+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
21+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
22+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
23+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
24+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

gpu-driver-util/main.go

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"path"
8+
"slices"
9+
"strings"
10+
11+
log "github.com/sirupsen/logrus"
12+
"github.com/urfave/cli/v2"
13+
)
14+
15+
// Filesystem locations used by the utility.
const (
	// LogFile is the file that receives all log output.
	LogFile = "/var/log/gpu-driver-util.log"
	// PCIDevicesRoot is the sysfs directory under which every PCI device is listed.
	PCIDevicesRoot = "/sys/bus/pci/devices"
	// DefaultSupportedGpusJsonPath is the default install location of the supported-gpus.json file.
	DefaultSupportedGpusJsonPath = "/usr/share/nvidia-driver-assistant/supported-gpus/supported-gpus.json"
)

// PCI identifiers, in the textual form in which sysfs reports them.
const (
	// NVIDIAVendorID is the PCI vendor id of NVIDIA.
	NVIDIAVendorID = "0x10de"
	// PCIDeviceClassVGA is the PCI device class code of VGA devices.
	PCIDeviceClassVGA = "0x030000"
	// PCIDeviceClassGPU is the PCI device class code of GPU devices.
	PCIDeviceClassGPU = "0x030200"
)

// Driver hints describe which kernel module flavour a single GPU calls for.
const (
	// DriverHintUnknown marks a GPU that has no entry in supported-gpus.json.
	DriverHintUnknown = "unknown"
	// DriverHintOpenRequired marks a GPU that must use the OpenRM kernel modules.
	DriverHintOpenRequired = "open-required"
	// DriverHintProprietaryRequired marks a GPU that must use the proprietary kernel modules.
	DriverHintProprietaryRequired = "proprietary-required"
	// DriverHintAny marks a GPU that works with either kernel module type.
	DriverHintAny = "any-supported"
)

// Feature tags looked for in the "features" list of a supported-gpus.json entry.
const (
	// DriverFeatureKernelOpen indicates that the GPU supports OpenRM.
	DriverFeatureKernelOpen = "kernelopen"
	// DriverFeatureKernelGSPProprietary indicates that the GPU has GSP RM and supports the proprietary modules.
	DriverFeatureKernelGSPProprietary = "gsp_proprietary_supported"
)

// Kernel module types that the utility can print as its result.
const (
	// KernelModuleTypeOpen selects the OpenRM kernel modules of the NVIDIA CUDA driver.
	KernelModuleTypeOpen = "kernel-open"
	// KernelModuleTypeProprietary selects the closed/proprietary kernel modules of the NVIDIA CUDA driver.
	KernelModuleTypeProprietary = "kernel"
)
48+
49+
// supportedGpusJsonPath is the location of the supported-gpus.json file. It is
// the destination of the "supported-gpus-file" command line flag.
var supportedGpusJsonPath string

// driverBranch is the driver branch number. It is the destination of the
// "driver-branch" command line flag (or the DRIVER_BRANCH environment variable).
var driverBranch int
53+
54+
// Types that mirror the parts of supported-gpus.json this utility reads.
type (
	// GPUDevice is one element of the "chips" list in supported-gpus.json.
	GPUDevice struct {
		// ID is the value of the "devid" key; it is the key used to look devices up.
		ID string `json:"devid"`
		// Name is the value of the "name" key.
		Name string `json:"name"`
		// LegacyBranch is the value of the "legacybranch" key.
		LegacyBranch string `json:"legacybranch"`
		// Features is the list of feature tags under the "features" key.
		Features []string `json:"features"`
	}

	// GPUData is the top-level document of supported-gpus.json.
	GPUData struct {
		// Chips holds every device entry under the "chips" key.
		Chips []GPUDevice `json:"chips"`
	}
)
64+
65+
func main() {
66+
logFile, err := initializeLogger()
67+
if err != nil {
68+
log.Fatal(err.Error())
69+
}
70+
defer logFile.Close()
71+
72+
// Create the top-level CLI app
73+
c := cli.NewApp()
74+
c.Name = "gpu-driver-util"
75+
c.Usage = "NVIDIA GPU Driver Utility Application"
76+
c.Version = "0.1.0"
77+
78+
getKernelModule := cli.Command{}
79+
getKernelModule.Name = "get-kernel-module-type"
80+
getKernelModule.Usage = "Automatically determine the kernel module type based on the GPUs detected."
81+
getKernelModule.Action = func(c *cli.Context) error {
82+
return GetKernelModule(c)
83+
}
84+
85+
getKernelModuleFlags := []cli.Flag{
86+
&cli.StringFlag{
87+
Name: "supported-gpus-file",
88+
Aliases: []string{"f"},
89+
Usage: "Specify location of the supported-gpus.json file",
90+
Value: DefaultSupportedGpusJsonPath,
91+
Destination: &supportedGpusJsonPath,
92+
Required: false,
93+
},
94+
&cli.IntFlag{
95+
Name: "driver-branch",
96+
Aliases: []string{"b"},
97+
Usage: "Specify driver branch",
98+
EnvVars: []string{"DRIVER_BRANCH"},
99+
Destination: &driverBranch,
100+
Required: true,
101+
},
102+
}
103+
104+
c.Commands = []*cli.Command{
105+
&getKernelModule,
106+
}
107+
108+
getKernelModule.Flags = append([]cli.Flag{}, getKernelModuleFlags...)
109+
110+
// Run the top-level CLI
111+
if err := c.Run(os.Args); err != nil {
112+
log.Fatal(fmt.Errorf("error running gpu-driver-util: %w", err))
113+
}
114+
115+
}
116+
117+
func initializeLogger() (*os.File, error) {
118+
logFile, err := os.OpenFile(LogFile, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
119+
if err != nil {
120+
return nil, fmt.Errorf("error opening file %s: %w", LogFile, err)
121+
}
122+
// Log as JSON instead of the default ASCII formatter.
123+
log.SetFormatter(&log.JSONFormatter{})
124+
// Output to file instead of stdout
125+
log.SetOutput(logFile)
126+
127+
// Only log the warning severity or above.
128+
log.SetLevel(log.DebugLevel)
129+
return logFile, nil
130+
}
131+
132+
func GetKernelModule(c *cli.Context) error {
133+
log.Infof("Starting the 'get-kernel-module' command of %s", c.App.Name)
134+
gpuDevices, err := getNvidiaGPUs()
135+
136+
log.Debugf("NVIDIA GPU devices found: %v", gpuDevices)
137+
138+
if err != nil {
139+
return err
140+
}
141+
142+
var gpuData GPUData
143+
gpuJSONString, err := os.ReadFile(supportedGpusJsonPath)
144+
if err != nil {
145+
log.Errorf("error opening the supported gpus file %s: %v", supportedGpusJsonPath, err)
146+
return err
147+
}
148+
149+
err = json.Unmarshal(gpuJSONString, &gpuData)
150+
if err != nil {
151+
return err
152+
}
153+
154+
searchMap := buildGPUSearchMap(gpuData)
155+
156+
if len(gpuDevices) > 0 {
157+
kernelModuleType, err := resolveKernelModuleType(gpuDevices, searchMap)
158+
if err != nil {
159+
log.Errorf("error resolving kernel module type: %v", err)
160+
return err
161+
}
162+
fmt.Println(kernelModuleType)
163+
}
164+
return nil
165+
}
166+
167+
func resolveKernelModuleType(gpuDevices []string, searchMap map[string]GPUDevice) (string, error) {
168+
169+
driverHints := getDriverHints(gpuDevices, searchMap)
170+
log.Debugf("driverHints: %v", driverHints)
171+
172+
// NOTE: driver hint "unknown" is assigned to a device that does not have an entry in supported-gpus.json.
173+
// In these cases, we assume that the gpu device is new and unreleased, and we default to OpenRM
174+
requiresOpenRM := slices.Contains(driverHints, DriverHintOpenRequired) || slices.Contains(driverHints, DriverHintUnknown)
175+
requiresProprietary := slices.Contains(driverHints, DriverHintProprietaryRequired)
176+
177+
if requiresOpenRM && requiresProprietary {
178+
return "", fmt.Errorf("unsupported GPU topology")
179+
} else if requiresOpenRM {
180+
return KernelModuleTypeOpen, nil
181+
} else if requiresProprietary {
182+
return KernelModuleTypeProprietary, nil
183+
} else {
184+
return getDriverBranchDefault(driverBranch), nil
185+
}
186+
}
187+
188+
func getDriverHints(gpuDevices []string, searchMap map[string]GPUDevice) []string {
189+
var driverHints []string
190+
191+
for _, gpuDevice := range gpuDevices {
192+
if val, ok := searchMap[gpuDevice]; ok {
193+
gpuFeatures := val.Features
194+
if slices.Contains(gpuFeatures, DriverFeatureKernelGSPProprietary) &&
195+
slices.Contains(gpuFeatures, DriverFeatureKernelOpen) {
196+
driverHints = append(driverHints, DriverHintAny)
197+
} else if slices.Contains(gpuFeatures, DriverFeatureKernelOpen) {
198+
driverHints = append(driverHints, DriverHintOpenRequired)
199+
} else {
200+
driverHints = append(driverHints, DriverHintProprietaryRequired)
201+
}
202+
} else {
203+
driverHints = append(driverHints, DriverHintUnknown)
204+
}
205+
}
206+
return driverHints
207+
}
208+
209+
func getDriverBranchDefault(driverBranch int) string {
210+
if driverBranch >= 560 {
211+
return KernelModuleTypeOpen
212+
}
213+
return KernelModuleTypeProprietary
214+
}
215+
216+
func buildGPUSearchMap(data GPUData) map[string]GPUDevice {
217+
var gpuMap = make(map[string]GPUDevice)
218+
for _, gpuDevice := range data.Chips {
219+
gpuMap[gpuDevice.ID] = gpuDevice
220+
}
221+
return gpuMap
222+
}
223+
224+
func getNvidiaGPUs() ([]string, error) {
225+
var nvDevices []string
226+
deviceDirs, err := os.ReadDir(PCIDevicesRoot)
227+
if err != nil {
228+
return nil, err
229+
}
230+
231+
for _, device := range deviceDirs {
232+
vendor, err := os.ReadFile(path.Join(PCIDevicesRoot, device.Name(), "vendor"))
233+
if err != nil {
234+
return nil, fmt.Errorf("failed to read pci device vendor name for %s: %w", device.Name(), err)
235+
}
236+
if strings.TrimSpace(string(vendor)) != NVIDIAVendorID {
237+
log.Tracef("Skipping device %s as it's not from the NVIDIA vendor", device.Name())
238+
continue
239+
}
240+
class, err := os.ReadFile(path.Join(PCIDevicesRoot, device.Name(), "class"))
241+
if err != nil {
242+
return nil, fmt.Errorf("failed to read pci device class name for %s: %w", device.Name(), err)
243+
}
244+
if strings.TrimSpace(string(class)) != PCIDeviceClassVGA && strings.TrimSpace(string(class)) != PCIDeviceClassGPU {
245+
log.Tracef("Skipping NVIDIA device %s as it's not of VGA/GPU device class", device.Name())
246+
continue
247+
}
248+
b, err := os.ReadFile(path.Join(PCIDevicesRoot, device.Name(), "device"))
249+
if err != nil {
250+
return nil, fmt.Errorf("failed to read pci device id for %s: %w", device.Name(), err)
251+
}
252+
253+
deviceID, err := sanitizeDeviceID(string(b))
254+
if err != nil {
255+
return nil, fmt.Errorf("found invalid device id for %s: %w", device.Name(), err)
256+
}
257+
258+
nvDevices = append(nvDevices, deviceID)
259+
}
260+
return nvDevices, nil
261+
}
262+
263+
// sanitizeDeviceID normalizes a PCI device id as read from sysfs into the
// format used by supported-gpus.json: a lowercase "0x" prefix followed by four
// uppercase hexadecimal digits. For example "0x1db6\n" becomes "0x1DB6".
//
// Surrounding whitespace is ignored. An error is returned when the trimmed
// input is not exactly "0x" (in either case) followed by four hexadecimal
// digits.
func sanitizeDeviceID(input string) (string, error) {
	id := strings.TrimSpace(input)

	if len(id) != 6 || !strings.EqualFold(id[:2], "0x") {
		// %q makes whitespace and newlines in the offending input visible.
		return "", fmt.Errorf("invalid device id format: %q", input)
	}

	digits := id[2:]
	notHex := func(r rune) bool {
		isHex := (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')
		return !isHex
	}
	if strings.IndexFunc(digits, notHex) >= 0 {
		return "", fmt.Errorf("invalid device id format: %q", input)
	}

	// Only the digits are uppercased; the prefix stays "0x" to match the
	// format in supported-gpus.json.
	return "0x" + strings.ToUpper(digits), nil
}

0 commit comments

Comments
 (0)