Skip to content

Commit d879662

Browse files
committed
Merge remote-tracking branch 'origin/master'
2 parents 39beac8 + ffcb1c7 commit d879662

18 files changed

+1340
-94
lines changed

ap.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err er
296296
offset++
297297
}
298298
}
299+
299300
newAP = MakeAP(newShape, newStrides, order, ap.Δ)
300301
}
301302
return

ap_test.go

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,6 @@ import (
77
"github.com/stretchr/testify/assert"
88
)
99

10-
type dummySlice struct {
11-
start, end, step int
12-
}
13-
14-
func (s dummySlice) Start() int { return s.start }
15-
func (s dummySlice) End() int { return s.end }
16-
func (s dummySlice) Step() int { return s.step }
17-
18-
func sli(start int, opt ...int) dummySlice {
19-
var end, step int
20-
switch len(opt) {
21-
case 0:
22-
end = start + 1
23-
step = 0
24-
case 1:
25-
end = opt[0]
26-
step = 1
27-
default:
28-
end = opt[0]
29-
step = opt[1]
30-
31-
}
32-
return dummySlice{start: start, end: end, step: step}
33-
}
34-
3510
func dummyScalar1() AP { return AP{} }
3611

3712
func dummyScalar2() AP { return AP{shape: Shape{1}} }
@@ -203,16 +178,16 @@ var sliceTests = []struct {
203178
contiguous bool
204179
}{
205180
// vectors
206-
{"a[0]", Shape{5}, []Slice{sli(0)}, 0, 1, ScalarShape(), nil, true},
207-
{"a[0:2]", Shape{5}, []Slice{sli(0, 2)}, 0, 2, Shape{2}, []int{1}, true},
208-
{"a[1:3]", Shape{5}, []Slice{sli(1, 3)}, 1, 3, Shape{2}, []int{1}, true},
209-
{"a[1:5:2]", Shape{5}, []Slice{sli(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false},
181+
{"a[0]", Shape{5}, []Slice{S(0)}, 0, 1, ScalarShape(), nil, true},
182+
{"a[0:2]", Shape{5}, []Slice{S(0, 2)}, 0, 2, Shape{2}, []int{1}, true},
183+
{"a[1:3]", Shape{5}, []Slice{S(1, 3)}, 1, 3, Shape{2}, []int{1}, true},
184+
{"a[1:5:2]", Shape{5}, []Slice{S(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false},
210185

211186
// matrix
212-
{"A[0]", Shape{2, 3}, []Slice{sli(0)}, 0, 3, Shape{1, 3}, []int{1}, true},
213-
{"A[1:3]", Shape{4, 5}, []Slice{sli(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true},
214-
{"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{sli(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened
215-
{"A[:, 1:3]", Shape{4, 5}, []Slice{nil, sli(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false},
187+
{"A[0]", Shape{2, 3}, []Slice{S(0)}, 0, 3, Shape{1, 3}, []int{1}, true},
188+
{"A[1:3]", Shape{4, 5}, []Slice{S(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true},
189+
{"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{S(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened
190+
{"A[:, 1:3]", Shape{4, 5}, []Slice{nil, S(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false},
216191

217192
// tensor
218193
{"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, 0, 4, Shape{2, 2}, []int{2, 1}, true},

api_matop.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ import (
77
// this file handles matops. While by default most of these matops should already have been defined as part of the
88
// Tensor interface, not all are possible(for example, concatenating a sparse tensor), hence the need for the following functions
99

10+
// Narrow narrows the tensor.
11+
func Narrow(t Tensor, dim, start, length int) (View, error) {
12+
dim = resolveAxis(dim, t.Dims())
13+
14+
slices := make([]Slice, MinInt(dim+1, t.Dims()))
15+
slices[dim] = S(start, start+length, 1)
16+
17+
return t.Slice(slices...)
18+
}
19+
1020
// Repeat repeats a Tensor along the axis and given the number of repeats.
1121
func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) {
1222
if r, ok := t.Engine().(Repeater); ok {
@@ -135,7 +145,7 @@ func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err
135145
if sbi, ok := a.Engine().(ByIndiceser); ok {
136146
return sbi.SelectByIndices(a, indices, axis, opts...)
137147
}
138-
return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine())
148+
return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
139149
}
140150

141151
// ByIndicesB is the backpropagation of ByIndices.
@@ -146,5 +156,41 @@ func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor,
146156
if sbi, ok := a.Engine().(ByIndiceser); ok {
147157
return sbi.SelectByIndicesB(a, b, indices, axis, opts...)
148158
}
149-
return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine())
159+
return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
160+
}
161+
162+
// LogSoftMax applies log softmax to the given tensor.
163+
func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
164+
if sm, ok := x.Engine().(SoftMaxer); ok {
165+
return sm.LogSoftMax(x, axis, opts...)
166+
}
167+
168+
return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", x.Engine())
169+
}
170+
171+
// SoftMax applies softmax to the given tensor.
172+
func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
173+
if sm, ok := x.Engine().(SoftMaxer); ok {
174+
return sm.SoftMax(x, axis, opts...)
175+
}
176+
177+
return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", x.Engine())
178+
}
179+
180+
// SoftMaxB applies softmax backwards operation
181+
func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
182+
if sm, ok := output.Engine().(SoftMaxer); ok {
183+
return sm.SoftMaxB(output, grad, axis, opts...)
184+
}
185+
186+
return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine())
187+
}
188+
189+
// LogSoftMaxB applies softmax backwards operation
190+
func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
191+
if sm, ok := output.Engine().(SoftMaxer); ok {
192+
return sm.LogSoftMaxB(output, grad, axis, opts...)
193+
}
194+
195+
return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine())
150196
}

defaultengine_selbyidx.go

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,27 @@ import (
77
"reflect"
88
)
99

10-
func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
11-
if !b.Shape().IsVectorLike() {
12-
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape())
10+
// SelectByIndices selects the values given the in `indices` tensor.
11+
//
12+
// Currently SelectByIndices only supports Dense tensors that do not require the use of iterators.
13+
// Please make a pull request to support tensors that require the use of an iterator to traverse data.
14+
func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
15+
if !indices.Shape().IsVectorLike() {
16+
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape())
1317
}
14-
if b.Dtype() != Int {
15-
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype())
18+
if indices.Dtype() != Int {
19+
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype())
1620
}
1721

1822
// if b is a scalar, then use Slice
1923
if a.Shape().IsScalarEquiv() {
2024
slices := make([]Slice, a.Shape().Dims())
21-
slices[axis] = ss(b.Data().([]int)[0])
25+
slices[axis] = ss(getInts(indices)[0])
2226
return a.Slice(slices...)
2327
}
2428

2529
expectedShape := a.Shape().Clone()
26-
expectedShape[axis] = b.Shape().TotalSize()
30+
expectedShape[axis] = indices.Shape().TotalSize()
2731

2832
var reuse DenseTensor
2933
var safe, toReuse, _ bool
@@ -36,9 +40,9 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal
3640
}
3741

3842
if !safe {
39-
if a.Shape()[axis] != b.Shape().TotalSize() {
43+
if a.Shape()[axis] != indices.Shape().TotalSize() {
4044
expected := a.Shape().Clone()
41-
expected[axis] = b.Shape().TotalSize()
45+
expected[axis] = indices.Shape().TotalSize()
4246
return nil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. The input a has %v ", expected, a.Shape())
4347
}
4448

@@ -49,7 +53,7 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal
4953
var dataA, dataB, dataReuse *storage.Header
5054
var ait, bit, iit Iterator
5155
var useIter bool
52-
if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil {
56+
if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, indices, reuse); err != nil {
5357
return nil, errors.Wrapf(err, "StdEng.Add")
5458
}
5559

@@ -130,39 +134,42 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da
130134
}
131135
}
132136

133-
// SelectByIndicesB is the backwards function of SelectByIndices.
134-
func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
137+
// SelectByIndicesB computes the gradient of the result of `SelectByIndices`.
138+
//
139+
// Currently SelectByIndicesB only supports Dense tensors that do not require the use of iterators.
140+
// Please make a pull request to support tensors that require the use of an iterator to traverse data.
141+
func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
135142
if !indices.Shape().IsVectorLike() {
136-
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape())
143+
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", outGrad.Shape())
137144
}
138145
if indices.Dtype() != Int {
139-
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype())
146+
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", outGrad.Dtype())
140147
}
141148

142149
// if b is a scalar, then use Slice
143-
if a.Shape().IsScalarEquiv() {
144-
slices := make([]Slice, a.Shape().Dims())
145-
slices[axis] = ss(b.Data().([]int)[0])
146-
return a.Slice(slices...)
150+
if input.Shape().IsScalarEquiv() {
151+
slices := make([]Slice, input.Shape().Dims())
152+
slices[axis] = ss(outGrad.Data().([]int)[0])
153+
return input.Slice(slices...)
147154
}
148155

149-
expectedShape := a.Shape().Clone()
156+
expectedShape := input.Shape().Clone()
150157

151158
var reuse DenseTensor
152159
var _, toReuse, _ bool
153-
if reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil {
160+
if reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil {
154161
return nil, errors.Wrap(err, "Unable to handle funcOpts")
155162
}
156163
if !toReuse && reuse == nil {
157164
// create reuse
158-
reuse = New(WithShape(expectedShape...), Of(a.Dtype()))
165+
reuse = New(WithShape(expectedShape...), Of(input.Dtype()))
159166
}
160167

161-
typ := a.Dtype().Type
168+
typ := input.Dtype().Type
162169
var _, dataB, dataReuse *storage.Header
163170
var _, bit, iit Iterator
164171
var useIter bool
165-
if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil {
172+
if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(input, outGrad, reuse); err != nil {
166173
return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB")
167174
}
168175

@@ -172,7 +179,7 @@ func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt
172179
return
173180
}
174181

175-
e.selectByIndicesB(axis, indices.Data().([]int), typ, dataB, dataReuse, b.(*Dense).AP, reuse.(*Dense).AP)
182+
e.selectByIndicesB(axis, getInts(indices), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP)
176183

177184
return reuse, nil
178185
}
@@ -228,8 +235,8 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data
228235
for i, idx := range indices {
229236
dstCoord[axis] = idx
230237
srcCoord[axis] = i
231-
dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...)
232-
start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...)
238+
dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)
239+
start, _ := Ltoi(apB.shape, apB.strides, srcCoord...)
233240

234241
for o := 0; o < outer; o++ {
235242
dstEnd := dstStart + axStride

0 commit comments

Comments
 (0)