@@ -2,8 +2,9 @@ package main
22
33import (
44 "path/filepath"
5- "reflect"
65 "testing"
6+
7+ "github.com/stretchr/testify/require"
78)
89
910func TestGetNvidiaConfig (t * testing.T ) {
@@ -414,39 +415,28 @@ func TestGetNvidiaConfig(t *testing.T) {
414415
415416 // For any tests that are expected to panic, make sure they do.
416417 if tc .expectedPanic {
417- mustPanic (t , getConfig )
418+ require . Panics (t , getConfig )
418419 return
419420 }
420421
421422 // For all other tests, just grab the config
422423 getConfig ()
423424
424425 // And start comparing the test results to the expected results.
425- if config == nil && tc .expectedConfig == nil {
426+ if tc .expectedConfig == nil {
427+ require .Nil (t , config , tc .description )
426428 return
427429 }
428- if config != nil && tc .expectedConfig != nil {
429- if ! reflect .DeepEqual (config .Devices , tc .expectedConfig .Devices ) {
430- t .Errorf ("Unexpected nvidiaConfig (got: %v, wanted: %v)" , config , tc .expectedConfig )
431- }
432- if ! reflect .DeepEqual (config .MigConfigDevices , tc .expectedConfig .MigConfigDevices ) {
433- t .Errorf ("Unexpected nvidiaConfig (got: %v, wanted: %v)" , config , tc .expectedConfig )
434- }
435- if ! reflect .DeepEqual (config .MigMonitorDevices , tc .expectedConfig .MigMonitorDevices ) {
436- t .Errorf ("Unexpected nvidiaConfig (got: %v, wanted: %v)" , config , tc .expectedConfig )
437- }
438- if ! reflect .DeepEqual (config .DriverCapabilities , tc .expectedConfig .DriverCapabilities ) {
439- t .Errorf ("Unexpected nvidiaConfig (got: %v, wanted: %v)" , config , tc .expectedConfig )
440- }
441- if ! elementsMatch (config .Requirements , tc .expectedConfig .Requirements ) {
442- t .Errorf ("Unexpected nvidiaConfig (got: %v, wanted: %v)" , config , tc .expectedConfig )
443- }
444- if ! reflect .DeepEqual (config .DisableRequire , tc .expectedConfig .DisableRequire ) {
445- t .Errorf ("Unexpected nvidiaConfig (got: %v, wanted: %v)" , config , tc .expectedConfig )
446- }
447- return
448- }
449- t .Errorf ("Unexpected nvidiaConfig (got: %v, wanted: %v)" , config , tc .expectedConfig )
430+
431+ require .NotNil (t , config , tc .description )
432+
433+ require .Equal (t , tc .expectedConfig .Devices , config .Devices )
434+ require .Equal (t , tc .expectedConfig .MigConfigDevices , config .MigConfigDevices )
435+ require .Equal (t , tc .expectedConfig .MigMonitorDevices , config .MigMonitorDevices )
436+ require .Equal (t , tc .expectedConfig .DriverCapabilities , config .DriverCapabilities )
437+
438+ require .ElementsMatch (t , tc .expectedConfig .Requirements , config .Requirements )
439+ require .Equal (t , tc .expectedConfig .DisableRequire , config .DisableRequire )
450440 })
451441 }
452442}
@@ -524,9 +514,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
524514 for _ , tc := range tests {
525515 t .Run (tc .description , func (t * testing.T ) {
526516 devices := getDevicesFromMounts (tc .mounts )
527- if ! reflect .DeepEqual (devices , tc .expectedDevices ) {
528- t .Errorf ("Unexpected devices (got: %v, wanted: %v)" , * devices , * tc .expectedDevices )
529- }
517+ require .Equal (t , tc .expectedDevices , devices )
530518 })
531519 }
532520}
@@ -639,36 +627,8 @@ func TestDeviceListSourcePriority(t *testing.T) {
639627
640628 // For all other tests, just grab the devices and check the results
641629 getDevices ()
642- if ! reflect .DeepEqual (devices , tc .expectedDevices ) {
643- t .Errorf ("Unexpected devices (got: %v, wanted: %v)" , * devices , * tc .expectedDevices )
644- }
645- })
646- }
647- }
648-
649- func elementsMatch (slice0 , slice1 []string ) bool {
650- map0 := make (map [string ]int )
651- map1 := make (map [string ]int )
652-
653- for _ , e := range slice0 {
654- map0 [e ]++
655- }
656630
657- for _ , e := range slice1 {
658- map1 [e ]++
659- }
660-
661- for k0 , v0 := range map0 {
662- if map1 [k0 ] != v0 {
663- return false
664- }
665- }
666-
667- for k1 , v1 := range map1 {
668- if map0 [k1 ] != v1 {
669- return false
670- }
631+ require .Equal (t , tc .expectedDevices , devices )
632+ })
671633 }
672-
673- return true
674634}
0 commit comments