Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 6db5620

Browse files
vdevaram authored and Adam Procter committed
Constant folding with Quantize (#1833)
* Constant folding with Quantize * updated with review comments
1 parent 27530c6 commit 6db5620

File tree

3 files changed

+120
-1
lines changed

3 files changed

+120
-1
lines changed

src/ngraph/pass/constant_folding.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "ngraph/op/multiply.hpp"
3030
#include "ngraph/op/negative.hpp"
3131
#include "ngraph/op/pad.hpp"
32+
#include "ngraph/op/quantize.hpp"
3233
#include "ngraph/op/reshape.hpp"
3334
#include "ngraph/op/subtract.hpp"
3435
#include "ngraph/pattern/matcher.hpp"
@@ -43,6 +44,7 @@
4344
#include "ngraph/runtime/reference/multiply.hpp"
4445
#include "ngraph/runtime/reference/negate.hpp"
4546
#include "ngraph/runtime/reference/pad.hpp"
47+
#include "ngraph/runtime/reference/quantize.hpp"
4648
#include "ngraph/runtime/reference/reshape.hpp"
4749
#include "ngraph/runtime/reference/subtract.hpp"
4850

@@ -529,3 +531,81 @@ void ngraph::pass::ConstantFolding::construct_constant_dequantize()
529531
auto dequantize_matcher = make_shared<pattern::Matcher>(dequant, constant_dequantize_callback);
530532
this->add_matcher(dequantize_matcher);
531533
}
534+
535+
template <class REAL, class QUANT>
536+
shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constant,
537+
shared_ptr<op::Quantize> quant,
538+
shared_ptr<op::Constant> scale,
539+
shared_ptr<op::Constant> offset)
540+
{
541+
auto out_shape = constant->get_shape();
542+
vector<QUANT> out_vec(shape_size(out_shape));
543+
544+
runtime::reference::quantize<REAL, QUANT>(constant->get_vector<REAL>().data(),
545+
scale->get_vector<REAL>().data(),
546+
offset->get_vector<QUANT>().data(),
547+
out_vec.data(),
548+
constant->get_shape(),
549+
scale->get_shape(),
550+
quant->get_axes());
551+
552+
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec);
553+
}
554+
555+
void ngraph::pass::ConstantFolding::construct_constant_quantize()
556+
{
557+
auto constant_label =
558+
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
559+
auto q_scale = op::Constant::create(element::f32, Shape{}, {1});
560+
auto q_offset = op::Constant::create(element::i8, Shape{}, {0});
561+
auto mode = op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO;
562+
auto quant_op =
563+
make_shared<op::Quantize>(constant_label, q_scale, q_offset, element::i8, AxisSet{}, mode);
564+
auto quant = make_shared<pattern::op::Label>(quant_op, nullptr, NodeVector{quant_op});
565+
566+
auto constant_quantize_callback = [constant_label, quant](pattern::Matcher& m) {
567+
NGRAPH_DEBUG << "In callback for constant_quantize_callback against node = "
568+
<< m.get_match_root()->get_name();
569+
570+
auto pattern_map = m.get_pattern_map();
571+
572+
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
573+
auto quant_match = pattern_map[quant];
574+
auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match);
575+
auto args = quant_match->get_arguments();
576+
auto scale = static_pointer_cast<op::Constant>(args[1]);
577+
auto offset = static_pointer_cast<op::Constant>(args[2]);
578+
579+
auto type = quant_match->get_element_type();
580+
581+
if (constant_match->get_element_type() != element::f32)
582+
{
583+
return false;
584+
}
585+
586+
if (quantize_op->get_round_mode() != op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO)
587+
{
588+
return false;
589+
}
590+
591+
if (type == element::u8)
592+
{
593+
replace_node(
594+
m.get_match_root(),
595+
make_constant_quantize<float, uint8_t>(constant_match, quantize_op, scale, offset));
596+
return true;
597+
}
598+
else if (type == element::i8)
599+
{
600+
replace_node(
601+
m.get_match_root(),
602+
make_constant_quantize<float, int8_t>(constant_match, quantize_op, scale, offset));
603+
return true;
604+
}
605+
606+
return false;
607+
};
608+
609+
auto quantize_matcher = make_shared<pattern::Matcher>(quant, constant_quantize_callback);
610+
this->add_matcher(quantize_matcher);
611+
}

src/ngraph/pass/constant_folding.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
3535
PAD,
3636
DEQUANTIZE,
3737
UNARY,
38-
BINARY
38+
BINARY,
39+
QUANTIZE
3940
};
4041

4142
public:
@@ -47,6 +48,7 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
4748
construct_constant_pad();
4849
construct_constant_unary();
4950
construct_constant_binary();
51+
construct_constant_quantize();
5052
construct_constant_dequantize();
5153
}
5254

@@ -65,6 +67,7 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
6567
case CFTransformations::UNARY: construct_constant_unary(); break;
6668
case CFTransformations::BINARY: construct_constant_binary(); break;
6769
case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break;
70+
case CFTransformations::QUANTIZE: construct_constant_quantize(); break;
6871
}
6972
}
7073
}
@@ -75,5 +78,6 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
7578
void construct_constant_pad();
7679
void construct_constant_unary();
7780
void construct_constant_binary();
81+
void construct_constant_quantize();
7882
void construct_constant_dequantize();
7983
};

test/constant_folding.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,38 @@ TEST(constant_folding, const_dequantize)
250250
vector<output_c_type> values_dequantize{0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12};
251251
ASSERT_EQ(values_dequantize, values_out);
252252
}
253+
254+
TEST(constant_folding, const_quantize)
255+
{
256+
Shape input_shape{12};
257+
Shape scale_offset_shape;
258+
AxisSet quantization_axes;
259+
260+
auto quant_type = element::u8;
261+
auto output_type = element::u8;
262+
typedef uint8_t output_c_type;
263+
264+
vector<float> values_in{1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0};
265+
auto constant = op::Constant::create(element::f32, input_shape, values_in);
266+
auto scale = op::Constant::create(element::f32, scale_offset_shape, {2});
267+
auto offset = op::Constant::create(quant_type, scale_offset_shape, {1});
268+
auto mode = op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO;
269+
auto quantize =
270+
make_shared<op::Quantize>(constant, scale, offset, output_type, quantization_axes, mode);
271+
auto f = make_shared<Function>(quantize, op::ParameterVector{});
272+
273+
pass::Manager pass_manager;
274+
pass_manager.register_pass<pass::ConstantFolding>();
275+
pass_manager.run_passes(f);
276+
277+
ASSERT_EQ(count_ops_of_type<op::Quantize>(f), 0);
278+
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
279+
280+
auto new_const =
281+
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
282+
ASSERT_TRUE(new_const);
283+
auto values_out = new_const->get_vector<output_c_type>();
284+
285+
vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
286+
ASSERT_EQ(values_quantize, values_out);
287+
}

0 commit comments

Comments
 (0)