|
19 | 19 | #include "ngraph/runtime/cpu/kernel/softmax.hpp" |
20 | 20 | #include "ngraph/runtime/cpu/mkldnn_invoke.hpp" |
21 | 21 | #include "ngraph/runtime/cpu/mkldnn_utils.hpp" |
| 22 | +#include "ngraph/runtime/reference/softmax.hpp" |
22 | 23 |
|
23 | 24 | using namespace std; |
24 | 25 | using namespace ngraph; |
@@ -131,8 +132,35 @@ namespace ngraph |
131 | 132 | }; |
132 | 133 | functors.emplace_back(functor); |
133 | 134 | } |
| 135 | + else if (arg_shape.size() == 4 && axes.size() == 3) |
| 136 | + { |
| 137 | + std::function<decltype(runtime::cpu::kernel::softmax_4d_3rd<float>)> kernel; |
| 138 | + |
| 139 | + SELECT_KERNEL(kernel, |
| 140 | + args[0].get_element_type(), |
| 141 | + runtime::cpu::kernel::softmax_4d_3rd); |
| 142 | + |
| 143 | + auto functor = [&, kernel, arg_shape, axes](CPURuntimeContext* ctx) { |
| 144 | + kernel(arg_tensor, out_tensor, arg_shape, axes); |
| 145 | + }; |
| 146 | + functors.emplace_back(functor); |
| 147 | + } |
| 148 | + else if (softmax->get_element_type() == element::f32) |
| 149 | + { |
| 150 | + NGRAPH_WARN << "Falling back to reference kernel for softmax " << arg_shape
| 151 | + << " over " << axes; |
| 152 | + auto functor = [&, arg_shape, axes](CPURuntimeContext* ctx) { |
| 153 | + runtime::reference::softmax<float>(static_cast<float*>(arg_tensor), |
| 154 | + static_cast<float*>(out_tensor), |
| 155 | + arg_shape, |
| 156 | + axes); |
| 157 | + }; |
| 158 | + functors.emplace_back(functor); |
| 159 | + } |
134 | 160 | else |
135 | 161 | { |
| 162 | + NGRAPH_ERR << "Unsupported Softmax " << arg_shape << " over " << axes |
| 163 | + << " in cpu buiilder"; |
136 | 164 | throw ngraph_error("Unsupported Softmax"); |
137 | 165 | } |
138 | 166 | } |
|
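Note on the new fallback path: when the shape/axes combination is not covered by an Eigen or MKL-DNN kernel, the diff above dispatches to runtime::reference::softmax<float> with (arg_tensor, out_tensor, arg_shape, axes). As a rough illustration of what such a reference softmax does for the 4-D-input, 3-reduced-axes case handled by softmax_4d_3rd, here is a minimal standalone C++ sketch. It assumes a row-major NCHW layout with the reduction over the last three axes; the function name naive_softmax_4d_over_chw and its parameters are hypothetical and are not part of nGraph.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Hypothetical reference softmax over the C, H, W axes of an NCHW tensor.
    void naive_softmax_4d_over_chw(const float* in, float* out,
                                   size_t n, size_t c, size_t h, size_t w)
    {
        const size_t inner = c * h * w; // elements reduced per batch entry
        for (size_t i = 0; i < n; ++i)
        {
            const float* src = in + i * inner;
            float* dst = out + i * inner;

            // Subtract the per-batch max before exponentiating for numerical stability.
            const float max_val = *std::max_element(src, src + inner);

            float sum = 0.f;
            for (size_t j = 0; j < inner; ++j)
            {
                dst[j] = std::exp(src[j] - max_val);
                sum += dst[j];
            }
            for (size_t j = 0; j < inner; ++j)
            {
                dst[j] /= sum; // normalize so each batch entry sums to 1
            }
        }
    }

The actual reference kernel invoked in the diff handles arbitrary shapes and axis sets rather than this single layout, which is why it is slower and only used as a warning-logged fallback.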