Skip to content

Commit 907f685

Browse files
committed
add go-based script to auto-select kernel module type
Signed-off-by: Tariq Ibrahim <[email protected]>
1 parent b119e94 commit 907f685

File tree

3 files changed

+291
-0
lines changed

3 files changed

+291
-0
lines changed

gpu-driver-util/src/go.mod

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module driver-assistant
2+
3+
go 1.23.0
4+
5+
require (
6+
github.com/sirupsen/logrus v1.9.3
7+
github.com/urfave/cli/v2 v2.27.5
8+
)
9+
10+
require (
11+
github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect
12+
github.com/russross/blackfriday/v2 v2.1.0 // indirect
13+
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
14+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
15+
)

gpu-driver-util/src/go.sum

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
github.com/cpuguy83/go-md2man/v2 v2.0.5 h1:ZtcqGrnekaHpVLArFSe4HK5DoKx1T0rq2DwVB0alcyc=
2+
github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
3+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
5+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
7+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
8+
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
9+
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
10+
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
11+
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
12+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
13+
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
14+
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
15+
github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w=
16+
github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ=
17+
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
18+
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
19+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
20+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
21+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
22+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
23+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
24+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

gpu-driver-util/src/main.go

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"path"
8+
"slices"
9+
"strings"
10+
11+
log "github.com/sirupsen/logrus"
12+
"github.com/urfave/cli/v2"
13+
)
14+
15+
const (
16+
// LogFile is the path for logging
17+
LogFile = "/var/log/gpu-driver-util.log"
18+
// PCIDevicesRoot represents base path for all pci devices under sysfs
19+
PCIDevicesRoot = "/sys/bus/pci/devices"
20+
// NVIDIAVendorID represents the NVIDIA PCI vendor id
21+
NVIDIAVendorID = "0x10de"
22+
// PCIDeviceClassVGA represents the pci device class code for VGA devices
23+
PCIDeviceClassVGA = "0x030000"
24+
// PCIDeviceClassGPU represents the pci device class code for GPU devices
25+
PCIDeviceClassGPU = "0x030200"
26+
// DefaultSupportedGpusJsonPath represents the default install location of the supported-gpus.json file
27+
DefaultSupportedGpusJsonPath = "/usr/share/nvidia-driver-assistant/supported-gpus/supported-gpus.json"
28+
29+
// DriverHintUnknown is used when the gpu device is not found in supported-gpus.json
30+
DriverHintUnknown = "unknown"
31+
// DriverHintOpenRequired is used when the gpu device compulsorily needs the OpenRM kernel modules
32+
DriverHintOpenRequired = "open-required"
33+
// DriverHintProprietaryRequired is used when the gpu device compulsorily needs the proprietary kernel modules
34+
DriverHintProprietaryRequired = "proprietary-required"
35+
// DriverHintAny is used when the gpu device can support either kernel module types
36+
DriverHintAny = "any-supported"
37+
38+
// DriverFeatureKernelOpen indicates that the gpu device supports OpenRM
39+
DriverFeatureKernelOpen = "kernelopen"
40+
// DriverFeatureKernelGSPProprietary indicates that the gpu device has GSP RM and supports proprietary modules
41+
DriverFeatureKernelGSPProprietary = "gsp_proprietary_supported"
42+
43+
// KernelModuleTypeOpen indicates the OpenRM Kernel Modules of the NVIDIA CUDA driver
44+
KernelModuleTypeOpen = "kernel-open"
45+
// KernelModuleTypeProprietary indicates the Closed/Proprietary Kernel Modules of the NVIDIA CUDA driver
46+
KernelModuleTypeProprietary = "kernel"
47+
)
48+
49+
var (
50+
supportedGpusJsonPath string
51+
driverBranch int
52+
)
53+
54+
type GPUDevice struct {
55+
ID string `json:"devid"`
56+
Name string `json:"name"`
57+
LegacyBranch string `json:"legacybranch"`
58+
Features []string `json:"features"`
59+
}
60+
61+
type GPUData struct {
62+
Chips []GPUDevice `json:"chips"`
63+
}
64+
65+
func main() {
66+
logFile, err := initializeLogger()
67+
if err != nil {
68+
log.Fatal(err.Error())
69+
}
70+
defer logFile.Close()
71+
72+
// Create the top-level CLI app
73+
c := cli.NewApp()
74+
c.Name = "gpu-driver-util"
75+
c.Usage = "NVIDIA GPU Driver Utility Application"
76+
c.Version = "0.1.0"
77+
78+
getKernelModule := cli.Command{}
79+
getKernelModule.Name = "get-kernel-module-type"
80+
getKernelModule.Usage = "Automatically determine the kernel module type based on the GPUs detected."
81+
getKernelModule.Action = func(c *cli.Context) error {
82+
return GetKernelModule(c)
83+
}
84+
85+
getKernelModuleFlags := []cli.Flag{
86+
&cli.StringFlag{
87+
Name: "supported-gpus-file",
88+
Aliases: []string{"f"},
89+
Usage: "Specify location of the supported-gpus.json file",
90+
Value: DefaultSupportedGpusJsonPath,
91+
Destination: &supportedGpusJsonPath,
92+
Required: false,
93+
},
94+
&cli.IntFlag{
95+
Name: "driver-branch",
96+
Aliases: []string{"b"},
97+
Usage: "Specify driver branch",
98+
EnvVars: []string{"DRIVER_BRANCH"},
99+
Destination: &driverBranch,
100+
Required: true,
101+
},
102+
}
103+
104+
c.Commands = []*cli.Command{
105+
&getKernelModule,
106+
}
107+
108+
getKernelModule.Flags = append([]cli.Flag{}, getKernelModuleFlags...)
109+
110+
// Run the top-level CLI
111+
if err := c.Run(os.Args); err != nil {
112+
log.Fatal(fmt.Errorf("error running gpu-driver-util: %w", err))
113+
}
114+
115+
}
116+
117+
func initializeLogger() (*os.File, error) {
118+
logFile, err := os.OpenFile(LogFile, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
119+
if err != nil {
120+
return nil, fmt.Errorf("error opening file %s: %w", LogFile, err)
121+
}
122+
// Log as JSON instead of the default ASCII formatter.
123+
log.SetFormatter(&log.JSONFormatter{})
124+
// Output to file instead of stdout
125+
log.SetOutput(logFile)
126+
127+
// Only log the warning severity or above.
128+
log.SetLevel(log.DebugLevel)
129+
return logFile, nil
130+
}
131+
132+
func GetKernelModule(c *cli.Context) error {
133+
log.Infof("Starting the 'get-kernel-module' command of %s", c.App.Name)
134+
gpuDevices, err := getNvidiaGPUs()
135+
if err != nil {
136+
return err
137+
}
138+
139+
var gpuData GPUData
140+
gpuJSONString, err := os.ReadFile(supportedGpusJsonPath)
141+
if err != nil {
142+
log.Errorf("error opening the supported gpus file %s: %v", supportedGpusJsonPath, err)
143+
return err
144+
}
145+
146+
err = json.Unmarshal(gpuJSONString, &gpuData)
147+
if err != nil {
148+
return err
149+
}
150+
151+
searchMap := buildGPUSearchMap(gpuData)
152+
153+
if len(gpuDevices) > 0 {
154+
kernelModuleType, err := resolveKernelModuleType(gpuDevices, searchMap)
155+
if err != nil {
156+
log.Errorf("error resolving kernel module type: %v", err)
157+
return err
158+
}
159+
fmt.Println(kernelModuleType)
160+
}
161+
return nil
162+
}
163+
164+
func resolveKernelModuleType(gpuDevices []string, searchMap map[string]GPUDevice) (string, error) {
165+
166+
driverHints := getDriverHints(gpuDevices, searchMap)
167+
log.Debugf("driverHints: %v", driverHints)
168+
169+
// NOTE: driver hint "unknown" is assigned to a device that does not have an entry in supported-gpus.json.
170+
// In these cases, we assume that the gpu device is new and unreleased, and we default to OpenRM
171+
requiresOpenRM := slices.Contains(driverHints, DriverHintOpenRequired) || slices.Contains(driverHints, DriverHintUnknown)
172+
requiresProprietary := slices.Contains(driverHints, DriverHintProprietaryRequired)
173+
174+
if requiresOpenRM && requiresProprietary {
175+
return "", fmt.Errorf("unsupported GPU topology")
176+
} else if requiresOpenRM {
177+
return KernelModuleTypeOpen, nil
178+
} else if requiresProprietary {
179+
return KernelModuleTypeProprietary, nil
180+
} else {
181+
return getDriverBranchDefault(driverBranch), nil
182+
}
183+
}
184+
185+
func getDriverHints(gpuDevices []string, searchMap map[string]GPUDevice) []string {
186+
var driverHints []string
187+
188+
for _, gpuDevice := range gpuDevices {
189+
if val, ok := searchMap[gpuDevice]; ok {
190+
gpuFeatures := val.Features
191+
if slices.Contains(gpuFeatures, DriverFeatureKernelGSPProprietary) &&
192+
slices.Contains(gpuFeatures, DriverFeatureKernelOpen) {
193+
driverHints = append(driverHints, DriverHintAny)
194+
} else if slices.Contains(gpuFeatures, DriverFeatureKernelOpen) {
195+
driverHints = append(driverHints, DriverHintOpenRequired)
196+
} else {
197+
driverHints = append(driverHints, DriverHintProprietaryRequired)
198+
}
199+
} else {
200+
driverHints = append(driverHints, DriverHintUnknown)
201+
}
202+
}
203+
return driverHints
204+
}
205+
206+
func getDriverBranchDefault(driverBranch int) string {
207+
if driverBranch >= 560 {
208+
return KernelModuleTypeOpen
209+
}
210+
return KernelModuleTypeProprietary
211+
}
212+
213+
func buildGPUSearchMap(data GPUData) map[string]GPUDevice {
214+
var gpuMap = make(map[string]GPUDevice)
215+
for _, gpuDevice := range data.Chips {
216+
gpuMap[gpuDevice.ID] = gpuDevice
217+
}
218+
return gpuMap
219+
}
220+
221+
func getNvidiaGPUs() ([]string, error) {
222+
var nvDevices []string
223+
deviceDirs, err := os.ReadDir(PCIDevicesRoot)
224+
if err != nil {
225+
return nil, err
226+
}
227+
228+
for _, device := range deviceDirs {
229+
vendor, err := os.ReadFile(path.Join(PCIDevicesRoot, device.Name(), "vendor"))
230+
if err != nil {
231+
return nil, fmt.Errorf("failed to read pci device vendor name for %s: %w", device.Name(), err)
232+
}
233+
if strings.TrimSpace(string(vendor)) != NVIDIAVendorID {
234+
log.Debugf("Skipping device %s as it's not from the NVIDIA vendor", device.Name())
235+
continue
236+
}
237+
class, err := os.ReadFile(path.Join(PCIDevicesRoot, device.Name(), "class"))
238+
if err != nil {
239+
return nil, fmt.Errorf("failed to read pci device class name for %s: %w", device.Name(), err)
240+
}
241+
if strings.TrimSpace(string(class)) != PCIDeviceClassVGA && strings.TrimSpace(string(class)) != PCIDeviceClassGPU {
242+
log.Debugf("Skipping NVIDIA device %s as it's not of VGA/GPU device class", device.Name())
243+
continue
244+
}
245+
deviceID, err := os.ReadFile(path.Join(PCIDevicesRoot, device.Name(), "device"))
246+
if err != nil {
247+
return nil, fmt.Errorf("failed to read pci device id for %s: %w", device.Name(), err)
248+
}
249+
nvDevices = append(nvDevices, strings.TrimSpace(string(deviceID)))
250+
}
251+
return nvDevices, nil
252+
}

0 commit comments

Comments
 (0)