diff --git a/src/impl/pooling_direct.jl b/src/impl/pooling_direct.jl index 78a24ac4..a0e65e34 100644 --- a/src/impl/pooling_direct.jl +++ b/src/impl/pooling_direct.jl @@ -95,14 +95,14 @@ for name in (:max, :mean, :lpnorm) elseif $(name == :mean) m += x[input_kw, input_kh, input_kd, c, batch_idx] elseif $(name == :lpnorm) - # y = (∑ᵢ xᵢ^p)^(1 / p), here to calculate ∑ᵢ xᵢ^p - m += x[input_kw, input_kh, input_kd, c, batch_idx]^p + # y = (∑ᵢ |xᵢ|^p)^(1 / p), here to calculate ∑ᵢ |xᵢ|^p + m += abs(x[input_kw, input_kh, input_kd, c, batch_idx])^p else error("Unimplemented codegen path") end end - # for lpnormpool, y = (∑ᵢ xᵢ^p)^(1 / p) + # for lpnormpool, y = (∑ᵢ |xᵢ|^p)^(1 / p) m = $(name == :lpnorm) ? m^(T(1) / p) : m y[w, h, d, c, batch_idx] = _alpha * m # + _beta * y[w, h, d, c, batch_idx] @@ -151,7 +151,7 @@ for name in (:max, :mean, :lpnorm) elseif $(name == :mean) m += x[input_kw, input_kh, input_kd, c, batch_idx] elseif $(name == :lpnorm) - m += x[input_kw, input_kh, input_kd, c, batch_idx]^p + m += abs(x[input_kw, input_kh, input_kd, c, batch_idx])^p else error("Unimplemented codegen path") end @@ -263,8 +263,9 @@ for name in (:max, :mean, :lpnorm) # Either does meanpool :( dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha elseif $(name == :lpnorm) - # y = (∑ᵢ xᵢ^p)^(1 / p), ∂y/∂xᵢ = xᵢ^(p-1) × y^(1-p) - grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p) + # y = (∑ᵢ |xᵢ|^p)^(1 / p), ∂y/∂xᵢ = |xᵢ|^(p-1) × y^(1-p) × sign(xᵢ) + xv = x[input_kw, input_kh, input_kd, c, batch_idx] + grad = abs(xv)^(p-1) * y_idx^(1-p) * sign(xv) dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad else error("Unimplemented codegen path") @@ -327,7 +328,8 @@ for name in (:max, :mean, :lpnorm) elseif $(name == :mean) dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...] elseif $(name == :lpnorm) - grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p) + xv = x[input_kw, input_kh, input_kd, c, batch_idx] + grad = abs(xv)^(p-1) * y_idx^(1-p) * sign(xv) dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad else error("Unimplemented codegen path")