@@ -6,121 +6,108 @@ import (
66 "github.com/stretchr/testify/assert"
77)
88
9- func TestDense_SelectByIndices (t * testing.T ) {
10- assert := assert .New (t )
11-
12- a := New (WithBacking ([]float64 {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 }), WithShape (3 , 2 , 4 ))
13- indices := New (WithBacking ([]int {1 , 1 }))
14-
15- e := StdEng {}
16-
17- a1 , err := e .SelectByIndices (a , indices , 1 )
18- if err != nil {
19- t .Errorf ("%v" , err )
20- }
21- correct1 := []float64 {4 , 5 , 6 , 7 , 4 , 5 , 6 , 7 , 12 , 13 , 14 , 15 , 12 , 13 , 14 , 15 , 20 , 21 , 22 , 23 , 20 , 21 , 22 , 23 }
22- assert .Equal (correct1 , a1 .Data ())
23-
24- a0 , err := e .SelectByIndices (a , indices , 0 )
25- if err != nil {
26- t .Errorf ("%v" , err )
27- }
28- correct0 := []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }
29- assert .Equal (correct0 , a0 .Data ())
9+ type selByIndicesTest struct {
10+ Name string
11+ Data interface {}
12+ Shape Shape
13+ Indices []int
14+ Axis int
15+ WillErr bool
16+
17+ Correct interface {}
18+ CorrectShape Shape
19+ }
3020
31- a2 , err := e .SelectByIndices (a , indices , 2 )
32- if err != nil {
33- t .Errorf ("%v" , err )
34- }
35- correct2 := []float64 {1 , 1 , 5 , 5 , 9 , 9 , 13 , 13 , 17 , 17 , 21 , 21 }
36- assert .Equal (correct2 , a2 .Data ())
21+ var selByIndicesTests = []selByIndicesTest {
22+ {Name : "3-tensor, axis 0" , Data : Range (Float64 , 0 , 24 ), Shape : Shape {3 , 2 , 4 }, Indices : []int {1 , 1 }, Axis : 0 , WillErr : false ,
23+ Correct : []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }, CorrectShape : Shape {2 , 2 , 4 }},
3724
38- // !safe
39- aUnsafe := a .Clone ().(* Dense )
40- indices = New (WithBacking ([]int {1 , 1 , 1 }))
41- aUnsafeSelect , err := e .SelectByIndices (aUnsafe , indices , 0 , UseUnsafe ())
42- if err != nil {
43- t .Errorf ("%v" , err )
44- }
45- correctUnsafe := []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }
46- assert .Equal (correctUnsafe , aUnsafeSelect .Data ())
25+ {Name : "3-tensor, axis 1" , Data : Range (Float64 , 0 , 24 ), Shape : Shape {3 , 2 , 4 }, Indices : []int {1 , 1 }, Axis : 1 , WillErr : false ,
26+ Correct : []float64 {4 , 5 , 6 , 7 , 4 , 5 , 6 , 7 , 12 , 13 , 14 , 15 , 12 , 13 , 14 , 15 , 20 , 21 , 22 , 23 , 20 , 21 , 22 , 23 }, CorrectShape : Shape {3 , 2 , 4 }},
4727
48- // 3 indices, just to make sure the sanity of the algorithm
49- indices = New (WithBacking ([]int {1 , 1 , 1 }))
50- a1 , err = e .SelectByIndices (a , indices , 1 )
51- if err != nil {
52- t .Errorf ("%v" , err )
53- }
54- correct1 = []float64 {
55- 4 , 5 , 6 , 7 ,
56- 4 , 5 , 6 , 7 ,
57- 4 , 5 , 6 , 7 ,
28+ {Name : "3-tensor, axis 2" , Data : Range (Float64 , 0 , 24 ), Shape : Shape {3 , 2 , 4 }, Indices : []int {1 , 1 }, Axis : 2 , WillErr : false ,
29+ Correct : []float64 {1 , 1 , 5 , 5 , 9 , 9 , 13 , 13 , 17 , 17 , 21 , 21 }, CorrectShape : Shape {3 , 2 , 2 }},
5830
59- 12 , 13 , 14 , 15 ,
60- 12 , 13 , 14 , 15 ,
61- 12 , 13 , 14 , 15 ,
31+ {Name : "Vector, axis 0" , Data : Range (Int , 0 , 5 ), Shape : Shape {5 }, Indices : []int {1 , 1 }, Axis : 0 , WillErr : false ,
32+ Correct : []int {1 , 1 }, CorrectShape : Shape {2 }},
6233
63- 20 , 21 , 22 , 23 ,
64- 20 , 21 , 22 , 23 ,
65- 20 , 21 , 22 , 23 ,
66- }
67- assert .Equal (correct1 , a1 .Data ())
34+ {Name : "Vector, axis 1" , Data : Range (Int , 0 , 5 ), Shape : Shape {5 }, Indices : []int {1 , 1 }, Axis : 1 , WillErr : true ,
35+ Correct : []int {1 , 1 }, CorrectShape : Shape {2 }},
36+ {Name : "(4,2) Matrix, with (10) indices" , Data : Range (Float32 , 0 , 8 ), Shape : Shape {4 , 2 }, Indices : []int {1 , 1 , 1 , 1 , 0 , 2 , 2 , 2 , 2 , 0 }, Axis : 0 , WillErr : false ,
37+ Correct : []float32 {2 , 3 , 2 , 3 , 2 , 3 , 2 , 3 , 0 , 1 , 4 , 5 , 4 , 5 , 4 , 5 , 4 , 5 , 0 , 1 }, CorrectShape : Shape {10 , 2 }},
38+ {Name : "(2,1) Matrx (colvec)m with (10) indies" , Data : Range (Float64 , 0 , 2 ), Shape : Shape {2 , 1 }, Indices : []int {1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 }, Axis : 0 , WillErr : false ,
39+ Correct : []float64 {1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 }, CorrectShape : Shape {10 },
40+ },
41+ }
6842
69- a0 , err = e .SelectByIndices (a , indices , 0 )
70- if err != nil {
71- t .Errorf ("%v" , err )
43+ func TestDense_SelectByIndices (t * testing.T ) {
44+ assert := assert .New (t )
45+ for i , tc := range selByIndicesTests {
46+ T := New (WithShape (tc .Shape ... ), WithBacking (tc .Data ))
47+ indices := New (WithBacking (tc .Indices ))
48+ ret , err := ByIndices (T , indices , tc .Axis )
49+ if checkErr (t , tc .WillErr , err , tc .Name , i ) {
50+ continue
51+ }
52+ assert .Equal (tc .Correct , ret .Data ())
53+ assert .True (tc .CorrectShape .Eq (ret .Shape ()))
7254 }
73- correct0 = []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }
74- assert .Equal (correct0 , a0 .Data ())
55+ }
7556
76- a2 , err = e .SelectByIndices (a , indices , 2 )
77- if err != nil {
78- t .Errorf ("%v" , err )
79- }
80- correct2 = []float64 {1 , 1 , 1 , 5 , 5 , 5 , 9 , 9 , 9 , 13 , 13 , 13 , 17 , 17 , 17 , 21 , 21 , 21 }
81- assert .Equal (correct2 , a2 .Data ())
57+ var selByIndicesBTests = []struct {
58+ selByIndicesTest
59+
60+ CorrectGrad interface {}
61+ CorrectGradShape Shape
62+ }{
63+ {
64+ selByIndicesTest : selByIndicesTests [0 ],
65+ CorrectGrad : []float64 {0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 16 , 18 , 20 , 22 , 24 , 26 , 28 , 30 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 },
66+ CorrectGradShape : Shape {3 , 2 , 4 },
67+ },
68+ {
69+ selByIndicesTest : selByIndicesTests [1 ],
70+ CorrectGrad : []float64 {0 , 0 , 0 , 0 , 8 , 10 , 12 , 14 , 0 , 0 , 0 , 0 , 24 , 26 , 28 , 30 , 0 , 0 , 0 , 0 , 40 , 42 , 44 , 46 },
71+ CorrectGradShape : Shape {3 , 2 , 4 },
72+ },
73+ {
74+ selByIndicesTest : selByIndicesTests [2 ],
75+ CorrectGrad : []float64 {0 , 2 , 0 , 0 , 0 , 10 , 0 , 0 , 0 , 18 , 0 , 0 , 0 , 26 , 0 , 0 , 0 , 34 , 0 , 0 , 0 , 42 , 0 , 0 },
76+ CorrectGradShape : Shape {3 , 2 , 4 },
77+ },
78+ {
79+ selByIndicesTest : selByIndicesTests [3 ],
80+ CorrectGrad : []int {0 , 2 , 0 , 0 , 0 },
81+ CorrectGradShape : Shape {5 },
82+ },
83+ {
84+ selByIndicesTest : selByIndicesTests [5 ],
85+ CorrectGrad : []float32 {4 , 6 , 8 , 12 , 8 , 12 , 0 , 0 },
86+ CorrectGradShape : Shape {4 , 2 },
87+ },
88+ {
89+ selByIndicesTest : selByIndicesTests [6 ],
90+ CorrectGrad : []float64 {0 , 10 },
91+ CorrectGradShape : Shape {2 , 1 },
92+ },
8293}
8394
8495func TestDense_SelectByIndicesB (t * testing.T ) {
8596
86- a := New (WithBacking ([]float64 {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 }), WithShape (3 , 2 , 4 ))
87- indices := New (WithBacking ([]int {1 , 1 }))
88-
89- t .Logf ("a\n %v" , a )
90-
91- e := StdEng {}
92-
93- a1 , err := e .SelectByIndices (a , indices , 1 )
94- if err != nil {
95- t .Errorf ("%v" , err )
96- }
97- t .Logf ("a1\n %v" , a1 )
98-
99- a1Grad , err := e .SelectByIndicesB (a , a1 , indices , 1 )
100- if err != nil {
101- t .Errorf ("%v" , err )
102- }
103- t .Logf ("a1Grad \n %v" , a1Grad )
104-
105- a0 , err := e .SelectByIndices (a , indices , 0 )
106- if err != nil {
107- t .Errorf ("%v" , err )
108- }
109- t .Logf ("a0\n %v" , a0 )
110- a0Grad , err := e .SelectByIndicesB (a , a0 , indices , 0 )
111- if err != nil {
112- t .Errorf ("%v" , err )
97+ assert := assert .New (t )
98+ for i , tc := range selByIndicesBTests {
99+ T := New (WithShape (tc .Shape ... ), WithBacking (tc .Data ))
100+ indices := New (WithBacking (tc .Indices ))
101+ ret , err := ByIndices (T , indices , tc .Axis )
102+ if checkErr (t , tc .WillErr , err , tc .Name , i ) {
103+ continue
104+ }
105+ grad , err := ByIndicesB (T , ret , indices , tc .Axis )
106+ if checkErr (t , tc .WillErr , err , tc .Name , i ) {
107+ continue
108+ }
109+ assert .Equal (tc .CorrectGrad , grad .Data (), "%v" , tc .Name )
110+ assert .True (tc .CorrectGradShape .Eq (grad .Shape ()), "%v - Grad shape should be %v. Got %v instead" , tc .Name , tc .CorrectGradShape , grad .Shape ())
113111 }
114- t .Logf ("a0Grad\n %v" , a0Grad )
115112
116- a2 , err := e .SelectByIndices (a , indices , 2 )
117- if err != nil {
118- t .Errorf ("%v" , err )
119- }
120- t .Logf ("\n %v" , a2 )
121- a2Grad , err := e .SelectByIndicesB (a , a2 , indices , 2 )
122- if err != nil {
123- t .Errorf ("%v" , err )
124- }
125- t .Logf ("a2Grad\n %v" , a2Grad )
126113}
0 commit comments