Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions gpu-driver-util/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module gpu-driver-util

go 1.23.0

require (
github.com/sirupsen/logrus v1.9.3
github.com/urfave/cli/v2 v2.27.5
)

require (
github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
)
24 changes: 24 additions & 0 deletions gpu-driver-util/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
github.com/cpuguy83/go-md2man/v2 v2.0.5 h1:ZtcqGrnekaHpVLArFSe4HK5DoKx1T0rq2DwVB0alcyc=
github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w=
github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
275 changes: 275 additions & 0 deletions gpu-driver-util/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
package main

import (
"encoding/json"
"fmt"
"os"
"path"
"slices"
"strings"

log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
)

const (
	// LogFile is the path of the file that all log output is written to.
	LogFile = "/var/log/gpu-driver-util.log"
	// PCIDevicesRoot is the sysfs directory that holds one entry per PCI device.
	PCIDevicesRoot = "/sys/bus/pci/devices"
	// NVIDIAVendorID is the NVIDIA PCI vendor id, in the form found in a device's sysfs "vendor" file.
	NVIDIAVendorID = "0x10de"
	// PCIDeviceClassVGA is the PCI device class code for VGA devices, in the form found in a device's sysfs "class" file.
	PCIDeviceClassVGA = "0x030000"
	// PCIDeviceClassGPU is the PCI device class code for GPU devices, in the form found in a device's sysfs "class" file.
	PCIDeviceClassGPU = "0x030200"
	// DefaultSupportedGpusJsonPath is the default install location of the supported-gpus.json file.
	DefaultSupportedGpusJsonPath = "/usr/share/nvidia-driver-assistant/supported-gpus/supported-gpus.json"

	// Driver hints: one hint is computed per detected GPU and describes which kernel module types that GPU accepts.

	// DriverHintUnknown is used when the GPU device is not found in supported-gpus.json.
	DriverHintUnknown = "unknown"
	// DriverHintOpenRequired is used when the GPU device can only run with the OpenRM kernel modules.
	DriverHintOpenRequired = "open-required"
	// DriverHintProprietaryRequired is used when the GPU device can only run with the proprietary kernel modules.
	DriverHintProprietaryRequired = "proprietary-required"
	// DriverHintAny is used when the GPU device supports either kernel module type.
	DriverHintAny = "any-supported"

	// Driver features: values looked up in the "features" list of a supported-gpus.json entry.

	// DriverFeatureKernelOpen indicates that the GPU device supports OpenRM.
	DriverFeatureKernelOpen = "kernelopen"
	// DriverFeatureKernelGSPProprietary indicates that the GPU device has GSP RM and supports the proprietary modules.
	DriverFeatureKernelGSPProprietary = "gsp_proprietary_supported"

	// Kernel module types: the values printed by the get-kernel-module-type command.

	// KernelModuleTypeOpen identifies the OpenRM kernel modules of the NVIDIA CUDA driver.
	KernelModuleTypeOpen = "kernel-open"
	// KernelModuleTypeProprietary identifies the closed/proprietary kernel modules of the NVIDIA CUDA driver.
	KernelModuleTypeProprietary = "kernel"
)

// Values populated by the flags of the get-kernel-module-type command.
var (
	// supportedGpusJsonPath is the location of the supported-gpus.json file;
	// set by --supported-gpus-file and defaulting to DefaultSupportedGpusJsonPath.
	supportedGpusJsonPath string
	// driverBranch is the driver branch number; set by --driver-branch or the
	// DRIVER_BRANCH environment variable. It selects the default kernel module
	// type when no detected GPU forces a choice.
	driverBranch int
)

// GPUDevice is a single entry of the "chips" list in supported-gpus.json.
type GPUDevice struct {
	// ID is the PCI device id of the chip; it is the key used to match detected
	// devices, which are normalised to the form "0x1DB6" before lookup.
	ID string `json:"devid"`
	// Name is the "name" field of the entry (presumably the marketing name; not used by this program).
	Name string `json:"name"`
	// LegacyBranch is the "legacybranch" field of the entry (not used by this program).
	LegacyBranch string `json:"legacybranch"`
	// Features lists the feature strings of the chip; it is searched for the DriverFeature* values.
	Features []string `json:"features"`
}

// GPUData is the top-level document decoded from supported-gpus.json.
type GPUData struct {
	// Chips holds every GPU entry listed under the "chips" key.
	Chips []GPUDevice `json:"chips"`
}

// main sets up file logging, registers the get-kernel-module-type command and
// runs the CLI with the process arguments. Any error returned by the CLI is
// logged at fatal level, which terminates the process.
func main() {
	logHandle, err := initializeLogger()
	if err != nil {
		log.Fatal(err.Error())
	}
	defer logHandle.Close()

	// The only command: prints the kernel module type that fits the detected GPUs.
	kernelModuleCmd := &cli.Command{
		Name:   "get-kernel-module-type",
		Usage:  "Automatically determine the kernel module type based on the GPUs detected.",
		Action: GetKernelModule,
		Flags: []cli.Flag{
			&cli.StringFlag{
				Name:        "supported-gpus-file",
				Aliases:     []string{"f"},
				Usage:       "Specify location of the supported-gpus.json file",
				Value:       DefaultSupportedGpusJsonPath,
				Destination: &supportedGpusJsonPath,
			},
			&cli.IntFlag{
				Name:        "driver-branch",
				Aliases:     []string{"b"},
				Usage:       "Specify driver branch",
				EnvVars:     []string{"DRIVER_BRANCH"},
				Destination: &driverBranch,
				Required:    true,
			},
		},
	}

	// Top-level CLI application.
	app := cli.NewApp()
	app.Name = "gpu-driver-util"
	app.Usage = "NVIDIA GPU Driver Utility Application"
	app.Version = "0.1.0"
	app.Commands = []*cli.Command{kernelModuleCmd}

	if runErr := app.Run(os.Args); runErr != nil {
		log.Fatal(fmt.Errorf("error running gpu-driver-util: %w", runErr))
	}
}

// initializeLogger opens (creating it if necessary) the log file at LogFile in
// append mode and configures the global logger to write JSON-formatted entries
// to it at debug level and above.
//
// It returns the opened file so that the caller can close it, or an error if
// the file cannot be opened.
func initializeLogger() (*os.File, error) {
	const openFlags = os.O_APPEND | os.O_CREATE | os.O_RDWR

	f, err := os.OpenFile(LogFile, openFlags, 0666)
	if err != nil {
		return nil, fmt.Errorf("error opening file %s: %w", LogFile, err)
	}

	// Log everything from debug severity upwards.
	log.SetLevel(log.DebugLevel)
	// Write to the log file instead of the logger's default output.
	log.SetOutput(f)
	// Emit JSON instead of the default text formatter.
	log.SetFormatter(&log.JSONFormatter{})

	return f, nil
}

// GetKernelModule is the action of the get-kernel-module-type command.
//
// It detects the NVIDIA GPUs on the PCI bus, loads the supported-gpus.json file
// named by supportedGpusJsonPath, and prints the resolved kernel module type on
// stdout. When no NVIDIA GPU is detected nothing is printed and nil is returned.
//
// Every failure is returned wrapped with context (and left to the caller to
// log) rather than being logged here and then returned.
func GetKernelModule(c *cli.Context) error {
	log.Infof("Starting the 'get-kernel-module-type' command of %s", c.App.Name)

	gpuDevices, err := getNvidiaGPUs()
	if err != nil {
		return fmt.Errorf("error detecting NVIDIA GPU devices: %w", err)
	}
	// Only log the device list once it is known to be valid.
	log.Debugf("NVIDIA GPU devices found: %v", gpuDevices)

	gpuJSON, err := os.ReadFile(supportedGpusJsonPath)
	if err != nil {
		return fmt.Errorf("error opening the supported gpus file %s: %w", supportedGpusJsonPath, err)
	}

	var gpuData GPUData
	if err := json.Unmarshal(gpuJSON, &gpuData); err != nil {
		return fmt.Errorf("error parsing the supported gpus file %s: %w", supportedGpusJsonPath, err)
	}

	if len(gpuDevices) == 0 {
		// Nothing to resolve; leave a trace so that the empty output is explainable.
		log.Info("No NVIDIA GPU devices found, no kernel module type to report")
		return nil
	}

	kernelModuleType, err := resolveKernelModuleType(gpuDevices, buildGPUSearchMap(gpuData))
	if err != nil {
		return fmt.Errorf("error resolving kernel module type: %w", err)
	}
	fmt.Println(kernelModuleType)
	return nil
}

// resolveKernelModuleType picks a single kernel module type that suits every
// device in gpuDevices, using searchMap to look up each device's capabilities.
//
// A device that is missing from searchMap (hint "unknown") is assumed to be
// new and unreleased and is therefore treated as requiring OpenRM. When no
// device forces a choice, the default for the package-level driverBranch is
// returned. An error is returned when some devices require OpenRM while others
// require the proprietary modules.
func resolveKernelModuleType(gpuDevices []string, searchMap map[string]GPUDevice) (string, error) {
	hints := getDriverHints(gpuDevices, searchMap)
	log.Debugf("driverHints: %v", hints)

	needsProprietary := slices.Contains(hints, DriverHintProprietaryRequired)
	needsOpen := slices.Contains(hints, DriverHintOpenRequired) ||
		slices.Contains(hints, DriverHintUnknown)

	switch {
	case needsOpen && needsProprietary:
		return "", fmt.Errorf("unsupported GPU topology")
	case needsOpen:
		return KernelModuleTypeOpen, nil
	case needsProprietary:
		return KernelModuleTypeProprietary, nil
	default:
		return getDriverBranchDefault(driverBranch), nil
	}
}

// getDriverHints returns one driver hint per entry of gpuDevices, in the same
// order:
//   - DriverHintUnknown when the device id is not a key of searchMap;
//   - DriverHintAny when its features include both OpenRM and GSP-proprietary support;
//   - DriverHintOpenRequired when its features include OpenRM support only;
//   - DriverHintProprietaryRequired otherwise.
func getDriverHints(gpuDevices []string, searchMap map[string]GPUDevice) []string {
	var hints []string

	for _, id := range gpuDevices {
		entry, known := searchMap[id]
		if !known {
			hints = append(hints, DriverHintUnknown)
			continue
		}

		supportsOpen := slices.Contains(entry.Features, DriverFeatureKernelOpen)
		supportsGSPProprietary := slices.Contains(entry.Features, DriverFeatureKernelGSPProprietary)

		switch {
		case supportsOpen && supportsGSPProprietary:
			hints = append(hints, DriverHintAny)
		case supportsOpen:
			hints = append(hints, DriverHintOpenRequired)
		default:
			hints = append(hints, DriverHintProprietaryRequired)
		}
	}
	return hints
}

func getDriverBranchDefault(driverBranch int) string {
if driverBranch >= 560 {
return KernelModuleTypeOpen
}
return KernelModuleTypeProprietary
}

// buildGPUSearchMap indexes the chips of data by their device id so that a
// detected device can be looked up directly. When several chips share an id,
// the last one listed wins.
func buildGPUSearchMap(data GPUData) map[string]GPUDevice {
	byID := make(map[string]GPUDevice, len(data.Chips))
	for i := range data.Chips {
		chip := data.Chips[i]
		byID[chip.ID] = chip
	}
	return byID
}

// getNvidiaGPUs walks the entries of PCIDevicesRoot and returns the normalised
// device id of every device whose vendor is NVIDIA and whose class is VGA or
// GPU. Other devices are skipped.
//
// It fails on the first entry whose vendor, class or device file cannot be
// read, or whose device id is malformed.
func getNvidiaGPUs() ([]string, error) {
	entries, err := os.ReadDir(PCIDevicesRoot)
	if err != nil {
		return nil, err
	}

	var ids []string
	for _, entry := range entries {
		name := entry.Name()
		devDir := path.Join(PCIDevicesRoot, name)

		rawVendor, err := os.ReadFile(path.Join(devDir, "vendor"))
		if err != nil {
			return nil, fmt.Errorf("failed to read pci device vendor name for %s: %w", name, err)
		}
		if vendorID := strings.TrimSpace(string(rawVendor)); vendorID != NVIDIAVendorID {
			log.Tracef("Skipping device %s as it's not from the NVIDIA vendor", name)
			continue
		}

		rawClass, err := os.ReadFile(path.Join(devDir, "class"))
		if err != nil {
			return nil, fmt.Errorf("failed to read pci device class name for %s: %w", name, err)
		}
		switch strings.TrimSpace(string(rawClass)) {
		case PCIDeviceClassVGA, PCIDeviceClassGPU:
			// A display-class device: keep going.
		default:
			log.Tracef("Skipping NVIDIA device %s as it's not of VGA/GPU device class", name)
			continue
		}

		rawID, err := os.ReadFile(path.Join(devDir, "device"))
		if err != nil {
			return nil, fmt.Errorf("failed to read pci device id for %s: %w", name, err)
		}
		id, err := sanitizeDeviceID(string(rawID))
		if err != nil {
			return nil, fmt.Errorf("found invalid device id for %s: %w", name, err)
		}

		ids = append(ids, id)
	}
	return ids, nil
}

// sanitizeDeviceID normalises a PCI device id, as read from sysfs, to the
// format used in supported-gpus.json: a lowercase "0x" prefix followed by four
// uppercase hexadecimal digits. For example, "0x1db6\n" becomes "0x1DB6".
//
// Surrounding whitespace is ignored. An error is returned when the remaining
// text is not a "0x" prefix followed by exactly four hexadecimal digits.
func sanitizeDeviceID(input string) (string, error) {
	const hexDigits = "0123456789abcdefABCDEF"

	id := strings.TrimSpace(input)
	if len(id) != 6 || !strings.EqualFold(id[:2], "0x") {
		return "", fmt.Errorf("invalid device id format: %q", input)
	}

	digits := id[2:]
	// Trimming every hex digit from both ends leaves nothing only when all
	// four characters are hex digits.
	if strings.Trim(digits, hexDigits) != "" {
		return "", fmt.Errorf("invalid device id format: %q", input)
	}

	// Only the digits are uppercased; the prefix stays "0x" to match the
	// format in supported-gpus.json.
	return "0x" + strings.ToUpper(digits), nil
}
Loading