forked from qiukun/gotorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdevice_test.go
More file actions
38 lines (33 loc) · 776 Bytes
/
device_test.go
File metadata and controls
38 lines (33 loc) · 776 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package gotorch_test
import (
"testing"
"github.com/stretchr/testify/assert"
torch "github.com/wangkuiyi/gotorch"
)
func TestDeviceTo(t *testing.T) {
a := assert.New(t)
a.NotPanics(func() {
var device torch.Device
if torch.IsCUDAAvailable() {
t.Log("CUDA is valid")
device = torch.NewDevice("cuda")
} else {
t.Log("No CUDA found; CPU only")
device = torch.NewDevice("cpu")
}
torch.RandN([]int64{2, 3}, false).To(device, torch.Float)
})
}
// TestDevicePanicWithUnknown asserts that torch.NewDevice panics when it is
// called with the device name "unknown".
func TestDevicePanicWithUnknown(t *testing.T) {
	a := assert.New(t)
	a.Panics(func() {
		torch.NewDevice("unknown")
	}, `torch.NewDevice("unknown") should panic`) // shown only when no panic occurs
}
// TestDeviceIsCUDNNAvailable calls torch.IsCUDNNAvailable and logs the
// outcome. It makes no assertion, so it passes for either result.
func TestDeviceIsCUDNNAvailable(t *testing.T) {
	status := "No CUDNN found"
	if torch.IsCUDNNAvailable() {
		status = "CUDNN is available"
	}
	t.Log(status)
}