11package main
22
33import (
4- _ "embed"
54 "encoding/json"
65 "fmt"
76 "os"
@@ -24,6 +23,8 @@ const (
2423 PCIDeviceClassVGA = "0x030000"
2524 // PCIDeviceClassGPU represents the pci device class code for GPU devices
2625 PCIDeviceClassGPU = "0x030200"
26+ // DefaultSupportedGpusJsonPath represents the default install location of the supported-gpus.json file
27+ DefaultSupportedGpusJsonPath = "/usr/share/nvidia-driver-assistant/supported-gpus/supported-gpus.json"
2728
2829 // DriverHintUnknown is used when the gpu device is not found in supported-gpus.json
2930 DriverHintUnknown = "unknown"
5051 driverBranch int
5152)
5253
53- //go:embed supported-gpus.json
54- var defaultSupportedGpusJson string
55-
5654type GPUDevice struct {
5755 ID string `json:"devid"`
5856 Name string `json:"name"`
@@ -89,6 +87,7 @@ func main() {
8987 Name : "supported-gpus-file" ,
9088 Aliases : []string {"f" },
9189 Usage : "Specify location of the supported-gpus.json file" ,
90+ Value : DefaultSupportedGpusJsonPath ,
9291 Destination : & supportedGpusJsonPath ,
9392 Required : false ,
9493 },
@@ -140,36 +139,33 @@ func GetKernelModule(c *cli.Context) error {
140139 return err
141140 }
142141
143- var jsonData []byte
144- if len (supportedGpusJsonPath ) > 0 {
145- jsonData , err = os .ReadFile (supportedGpusJsonPath )
146- if err != nil {
147- return fmt .Errorf ("error opening the supported gpus file %s: %w" , supportedGpusJsonPath , err )
148- }
149- } else {
150- jsonData = []byte (defaultSupportedGpusJson )
142+ var gpuData GPUData
143+ gpuJSONString , err := os .ReadFile (supportedGpusJsonPath )
144+ if err != nil {
145+ log .Errorf ("error opening the supported gpus file %s: %v" , supportedGpusJsonPath , err )
146+ return err
151147 }
152148
153- var gpuData GPUData
154- err = json .Unmarshal (jsonData , & gpuData )
149+ err = json .Unmarshal (gpuJSONString , & gpuData )
155150 if err != nil {
156- return fmt . Errorf ( "error unmarshaling the supported gpus json %s: %w" , supportedGpusJsonPath , err )
151+ return err
157152 }
158153
159154 searchMap := buildGPUSearchMap (gpuData )
160155
161156 if len (gpuDevices ) > 0 {
162157 kernelModuleType , err := resolveKernelModuleType (gpuDevices , searchMap )
163158 if err != nil {
164- return fmt .Errorf ("error resolving kernel module type: %w" , err )
159+ log .Errorf ("error resolving kernel module type: %v" , err )
160+ return err
165161 }
166162 fmt .Println (kernelModuleType )
167163 }
168164 return nil
169165}
170166
171167func resolveKernelModuleType (gpuDevices []string , searchMap map [string ]GPUDevice ) (string , error ) {
172- var kernelModuleType string
168+
173169 driverHints := getDriverHints (gpuDevices , searchMap )
174170 log .Debugf ("driverHints: %v" , driverHints )
175171
@@ -181,14 +177,12 @@ func resolveKernelModuleType(gpuDevices []string, searchMap map[string]GPUDevice
181177 if requiresOpenRM && requiresProprietary {
182178 return "" , fmt .Errorf ("unsupported GPU topology" )
183179 } else if requiresOpenRM {
184- kernelModuleType = KernelModuleTypeOpen
180+ return KernelModuleTypeOpen , nil
185181 } else if requiresProprietary {
186- kernelModuleType = KernelModuleTypeProprietary
182+ return KernelModuleTypeProprietary , nil
187183 } else {
188- kernelModuleType = getDriverBranchDefault (driverBranch )
184+ return getDriverBranchDefault (driverBranch ), nil
189185 }
190- log .Debugf ("printing the recommended kernel module type: %s" , kernelModuleType )
191- return kernelModuleType , nil
192186}
193187
194188func getDriverHints (gpuDevices []string , searchMap map [string ]GPUDevice ) []string {
0 commit comments