Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 64e1dbe

Browse files
gauridesdiyessi
authored andcommitted
Use all args for dropout (#3069)
1 parent 33c7413 commit 64e1dbe

File tree

4 files changed

+51
-30
lines changed

4 files changed

+51
-30
lines changed

src/ngraph/runtime/cpu/builder/dropout.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ namespace ngraph
3838

3939
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
4040
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
41+
auto arg4_buffer_index = external_function->get_buffer_index(args[4].get_name());
4142
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
4243
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
4344

4445
size_t element_count = out[0].get_size();
4546

4647
bool use_seed = drop->get_use_seed();
47-
double keep_prob = drop->get_keep_prob();
4848

4949
// Note: for performance optimization in addition to parallel RNG with multiple,
5050
// threads, we create, initialize and advance each msr here in builder instead of
@@ -56,7 +56,7 @@ namespace ngraph
5656
std::vector<std::minstd_rand> vmsr(nthr);
5757
if (use_seed)
5858
{
59-
uint32_t seed = drop->get_seed();
59+
uint64_t seed = drop->get_seed();
6060
for (size_t i = 0; i < nthr; i++)
6161
{
6262
std::minstd_rand msr;
@@ -72,13 +72,15 @@ namespace ngraph
7272
element_count,
7373
arg_buffer_index,
7474
arg1_buffer_index,
75+
arg4_buffer_index,
7576
out0_buffer_index,
7677
out1_buffer_index,
77-
keep_prob,
7878
vmsr,
7979
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
8080
bool training = static_cast<bool>(
8181
static_cast<float*>(ctx->buffer_data[arg1_buffer_index])[0]);
82+
double keep_prob =
83+
static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
8284
runtime::cpu::kernel::generate_dropout(
8385
static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
8486
static_cast<float*>(ctx->buffer_data[out0_buffer_index]),
@@ -96,13 +98,15 @@ namespace ngraph
9698
element_count,
9799
arg_buffer_index,
98100
arg1_buffer_index,
101+
arg4_buffer_index,
99102
out0_buffer_index,
100103
out1_buffer_index,
101-
keep_prob,
102104
vmsr,
103105
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
104106
bool training = static_cast<bool>(
105107
static_cast<double*>(ctx->buffer_data[arg1_buffer_index])[0]);
108+
double keep_prob =
109+
static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
106110
runtime::cpu::kernel::generate_dropout(
107111
static_cast<double*>(ctx->buffer_data[arg_buffer_index]),
108112
static_cast<double*>(ctx->buffer_data[out0_buffer_index]),

src/ngraph/runtime/cpu/op/dropout.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@ using namespace ngraph;
2626
op::Dropout::Dropout(const std::shared_ptr<Node>& input,
2727
const std::shared_ptr<Node>& gm_const,
2828
const std::shared_ptr<Node>& use_seed,
29-
const uint32_t seed,
30-
const double keep_prob)
31-
: Op("Dropout", check_single_output_args({input, gm_const, use_seed}))
32-
, m_seed(seed)
33-
, m_keep_prob(keep_prob)
29+
const std::shared_ptr<Node>& seed,
30+
const std::shared_ptr<Node>& keep_prob)
31+
: Op("Dropout", check_single_output_args({input, gm_const, use_seed, seed, keep_prob}))
3432
{
3533
constructor_validate_and_infer_types();
3634

@@ -41,13 +39,13 @@ op::Dropout::Dropout(const std::shared_ptr<Node>& input,
4139

4240
shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) const
4341
{
44-
if (new_args.size() != 3)
42+
if (new_args.size() != 5)
4543
{
4644
throw ngraph_error("Incorrect number of new arguments");
4745
}
4846

4947
return make_shared<Dropout>(
50-
new_args.at(0), new_args.at(1), new_args.at(2), m_seed, m_keep_prob);
48+
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
5149
}
5250

5351
bool op::Dropout::get_use_seed() const
@@ -60,3 +58,14 @@ bool op::Dropout::get_use_seed() const
6058
}
6159
return use_seed;
6260
}
61+
62+
uint64_t op::Dropout::get_seed() const
63+
{
64+
uint64_t seed = 0;
65+
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(3)))
66+
{
67+
auto seed_ptr = static_cast<const uint64_t*>(const_op->get_data_ptr());
68+
seed = *seed_ptr;
69+
}
70+
return seed;
71+
}

src/ngraph/runtime/cpu/op/dropout.hpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,15 @@ namespace ngraph
2929
Dropout(const std::shared_ptr<Node>& input,
3030
const std::shared_ptr<Node>& gm_const,
3131
const std::shared_ptr<Node>& use_seed,
32-
const uint32_t seed,
33-
const double keep_prob); // keep_prob = 1 - dropout_prob
32+
const std::shared_ptr<Node>& seed,
33+
const std::shared_ptr<Node>& keep_prob); // keep_prob = 1 - dropout_prob
3434

3535
bool get_use_seed() const;
36-
uint32_t get_seed() const { return m_seed; }
37-
double get_keep_prob() const { return m_keep_prob; }
38-
void set_seed(uint32_t new_seed) { m_seed = new_seed; }
39-
void set_keep_prob(double new_keep_prob) { m_keep_prob = new_keep_prob; }
36+
uint64_t get_seed() const;
37+
double get_keep_prob() const;
38+
4039
virtual std::shared_ptr<Node>
4140
copy_with_new_args(const NodeVector& new_args) const override;
42-
43-
private:
44-
uint32_t m_seed;
45-
double m_keep_prob;
4641
};
4742
}
4843
}

src/ngraph/runtime/cpu/pass/cpu_fusion.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -923,8 +923,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
923923
auto x = std::make_shared<pattern::op::Label>(element::f32, shape);
924924
auto x_label = std::make_shared<pattern::op::Label>(x, nullptr, NodeVector{x});
925925

926-
uint32_t seed = 1234;
927-
auto seed_label = std::make_shared<pattern::op::Label>(element::u32, Shape{0});
926+
uint64_t seed = 1234;
927+
auto seed_label = std::make_shared<pattern::op::Label>(element::u64, Shape{0});
928928

929929
double value = 0.9;
930930
auto value_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {value});
@@ -960,15 +960,28 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
960960
NGRAPH_DEBUG << "training argument to GenerateMask must be constant";
961961
return false;
962962
}
963+
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(2)))
964+
{
965+
NGRAPH_DEBUG << "use_seed argument to GenerateMask must be constant";
966+
return false;
967+
}
968+
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(3)))
969+
{
970+
NGRAPH_DEBUG << "seed argument to GenerateMask must be constant";
971+
return false;
972+
}
973+
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(4)))
974+
{
975+
NGRAPH_DEBUG << "probability argument to GenerateMask must be constant";
976+
return false;
977+
}
963978

964-
auto gm_value = gm->get_probability();
965-
auto gm_seed = gm->get_seed();
966-
967-
auto training = gm->get_argument(0); //for training purpose this is always going to be 1
968-
auto use_seed_arg = gm->get_argument(2); // this is the use_seed node
979+
auto dropout_n = std::make_shared<ngraph::op::Dropout>(pattern_map[x],
980+
gm->get_argument(0),
981+
gm->get_argument(2),
982+
gm->get_argument(3),
983+
gm->get_argument(4));
969984

970-
auto dropout_n = std::make_shared<ngraph::op::Dropout>(
971-
pattern_map[x], training, use_seed_arg, gm_seed, gm_value);
972985
auto goe1 = std::make_shared<ngraph::op::GetOutputElement>(dropout_n, 0);
973986
ngraph::replace_node(m.get_match_root(), goe1);
974987

0 commit comments

Comments
 (0)