|
29 | 29 | #include "ngraph/op/multiply.hpp" |
30 | 30 | #include "ngraph/op/negative.hpp" |
31 | 31 | #include "ngraph/op/pad.hpp" |
| 32 | +#include "ngraph/op/quantize.hpp" |
32 | 33 | #include "ngraph/op/reshape.hpp" |
33 | 34 | #include "ngraph/op/subtract.hpp" |
34 | 35 | #include "ngraph/pattern/matcher.hpp" |
|
43 | 44 | #include "ngraph/runtime/reference/multiply.hpp" |
44 | 45 | #include "ngraph/runtime/reference/negate.hpp" |
45 | 46 | #include "ngraph/runtime/reference/pad.hpp" |
| 47 | +#include "ngraph/runtime/reference/quantize.hpp" |
46 | 48 | #include "ngraph/runtime/reference/reshape.hpp" |
47 | 49 | #include "ngraph/runtime/reference/subtract.hpp" |
48 | 50 |
|
@@ -529,3 +531,81 @@ void ngraph::pass::ConstantFolding::construct_constant_dequantize() |
529 | 531 | auto dequantize_matcher = make_shared<pattern::Matcher>(dequant, constant_dequantize_callback); |
530 | 532 | this->add_matcher(dequantize_matcher); |
531 | 533 | } |
| 534 | + |
| 535 | +template <class REAL, class QUANT> |
| 536 | +shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constant, |
| 537 | + shared_ptr<op::Quantize> quant, |
| 538 | + shared_ptr<op::Constant> scale, |
| 539 | + shared_ptr<op::Constant> offset) |
| 540 | +{ |
| 541 | + auto out_shape = constant->get_shape(); |
| 542 | + vector<QUANT> out_vec(shape_size(out_shape)); |
| 543 | + |
| 544 | + runtime::reference::quantize<REAL, QUANT>(constant->get_vector<REAL>().data(), |
| 545 | + scale->get_vector<REAL>().data(), |
| 546 | + offset->get_vector<QUANT>().data(), |
| 547 | + out_vec.data(), |
| 548 | + constant->get_shape(), |
| 549 | + scale->get_shape(), |
| 550 | + quant->get_axes()); |
| 551 | + |
| 552 | + return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec); |
| 553 | +} |
| 554 | + |
| 555 | +void ngraph::pass::ConstantFolding::construct_constant_quantize() |
| 556 | +{ |
| 557 | + auto constant_label = |
| 558 | + make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>()); |
| 559 | + auto q_scale = op::Constant::create(element::f32, Shape{}, {1}); |
| 560 | + auto q_offset = op::Constant::create(element::i8, Shape{}, {0}); |
| 561 | + auto mode = op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO; |
| 562 | + auto quant_op = |
| 563 | + make_shared<op::Quantize>(constant_label, q_scale, q_offset, element::i8, AxisSet{}, mode); |
| 564 | + auto quant = make_shared<pattern::op::Label>(quant_op, nullptr, NodeVector{quant_op}); |
| 565 | + |
| 566 | + auto constant_quantize_callback = [constant_label, quant](pattern::Matcher& m) { |
| 567 | + NGRAPH_DEBUG << "In callback for constant_quantize_callback against node = " |
| 568 | + << m.get_match_root()->get_name(); |
| 569 | + |
| 570 | + auto pattern_map = m.get_pattern_map(); |
| 571 | + |
| 572 | + auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]); |
| 573 | + auto quant_match = pattern_map[quant]; |
| 574 | + auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match); |
| 575 | + auto args = quant_match->get_arguments(); |
| 576 | + auto scale = static_pointer_cast<op::Constant>(args[1]); |
| 577 | + auto offset = static_pointer_cast<op::Constant>(args[2]); |
| 578 | + |
| 579 | + auto type = quant_match->get_element_type(); |
| 580 | + |
| 581 | + if (constant_match->get_element_type() != element::f32) |
| 582 | + { |
| 583 | + return false; |
| 584 | + } |
| 585 | + |
| 586 | + if (quantize_op->get_round_mode() != op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO) |
| 587 | + { |
| 588 | + return false; |
| 589 | + } |
| 590 | + |
| 591 | + if (type == element::u8) |
| 592 | + { |
| 593 | + replace_node( |
| 594 | + m.get_match_root(), |
| 595 | + make_constant_quantize<float, uint8_t>(constant_match, quantize_op, scale, offset)); |
| 596 | + return true; |
| 597 | + } |
| 598 | + else if (type == element::i8) |
| 599 | + { |
| 600 | + replace_node( |
| 601 | + m.get_match_root(), |
| 602 | + make_constant_quantize<float, int8_t>(constant_match, quantize_op, scale, offset)); |
| 603 | + return true; |
| 604 | + } |
| 605 | + |
| 606 | + return false; |
| 607 | + }; |
| 608 | + |
| 609 | + auto quantize_matcher = make_shared<pattern::Matcher>(quant, constant_quantize_callback); |
| 610 | + this->add_matcher(quantize_matcher); |
| 611 | +} |
0 commit comments