11package main
22
33import (
4+ _ "embed"
45 "encoding/json"
56 "fmt"
67 "os"
@@ -23,8 +24,6 @@ const (
2324 PCIDeviceClassVGA = "0x030000"
2425 // PCIDeviceClassGPU represents the pci device class code for GPU devices
2526 PCIDeviceClassGPU = "0x030200"
26- // DefaultSupportedGpusJsonPath represents the default install location of the supported-gpus.json file
27- DefaultSupportedGpusJsonPath = "/usr/share/nvidia-driver-assistant/supported-gpus/supported-gpus.json"
2827
2928 // DriverHintUnknown is used when the gpu device is not found in supported-gpus.json
3029 DriverHintUnknown = "unknown"
5150 driverBranch int
5251)
5352
53+ //go:embed supported-gpus.json
54+ var defaultSupportedGpusJson string
55+
5456type GPUDevice struct {
5557 ID string `json:"devid"`
5658 Name string `json:"name"`
@@ -87,7 +89,6 @@ func main() {
8789 Name : "supported-gpus-file" ,
8890 Aliases : []string {"f" },
8991 Usage : "Specify location of the supported-gpus.json file" ,
90- Value : DefaultSupportedGpusJsonPath ,
9192 Destination : & supportedGpusJsonPath ,
9293 Required : false ,
9394 },
@@ -139,33 +140,36 @@ func GetKernelModule(c *cli.Context) error {
139140 return err
140141 }
141142
142- var gpuData GPUData
143- gpuJSONString , err := os .ReadFile (supportedGpusJsonPath )
144- if err != nil {
145- log .Errorf ("error opening the supported gpus file %s: %v" , supportedGpusJsonPath , err )
146- return err
143+ var jsonData []byte
144+ if len (supportedGpusJsonPath ) > 0 {
145+ jsonData , err = os .ReadFile (supportedGpusJsonPath )
146+ if err != nil {
147+ return fmt .Errorf ("error opening the supported gpus file %s: %w" , supportedGpusJsonPath , err )
148+ }
149+ } else {
150+ jsonData = []byte (defaultSupportedGpusJson )
147151 }
148152
149- err = json .Unmarshal (gpuJSONString , & gpuData )
153+ var gpuData GPUData
154+ err = json .Unmarshal (jsonData , & gpuData )
150155 if err != nil {
151- return err
156+ return fmt . Errorf ( "error unmarshaling the supported gpus json %s: %w" , supportedGpusJsonPath , err )
152157 }
153158
154159 searchMap := buildGPUSearchMap (gpuData )
155160
156161 if len (gpuDevices ) > 0 {
157162 kernelModuleType , err := resolveKernelModuleType (gpuDevices , searchMap )
158163 if err != nil {
159- log .Errorf ("error resolving kernel module type: %v" , err )
160- return err
164+ return fmt .Errorf ("error resolving kernel module type: %w" , err )
161165 }
162166 fmt .Println (kernelModuleType )
163167 }
164168 return nil
165169}
166170
167171func resolveKernelModuleType (gpuDevices []string , searchMap map [string ]GPUDevice ) (string , error ) {
168-
172+ var kernelModuleType string
169173 driverHints := getDriverHints (gpuDevices , searchMap )
170174 log .Debugf ("driverHints: %v" , driverHints )
171175
@@ -177,12 +181,14 @@ func resolveKernelModuleType(gpuDevices []string, searchMap map[string]GPUDevice
177181 if requiresOpenRM && requiresProprietary {
178182 return "" , fmt .Errorf ("unsupported GPU topology" )
179183 } else if requiresOpenRM {
180- return KernelModuleTypeOpen , nil
184+ kernelModuleType = KernelModuleTypeOpen
181185 } else if requiresProprietary {
182- return KernelModuleTypeProprietary , nil
186+ kernelModuleType = KernelModuleTypeProprietary
183187 } else {
184- return getDriverBranchDefault (driverBranch ), nil
188+ kernelModuleType = getDriverBranchDefault (driverBranch )
185189 }
190+ log .Debugf ("printing the recommended kernel module type: %s" , kernelModuleType )
191+ return kernelModuleType , nil
186192}
187193
188194func getDriverHints (gpuDevices []string , searchMap map [string ]GPUDevice ) []string {
0 commit comments