Skip to content

Commit de52a13

Browse files
committed
matmul in fp32
1 parent 7ac93d3 commit de52a13

File tree

7 files changed

+29
-39
lines changed

7 files changed

+29
-39
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
212212
} else {
213213
m_op_case = 1;
214214
}
215+
break;
215216
}
216217
default:
217218
break;

ggml/src/ggml-openvino/ggml-decoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
139139
std::vector<std::string> m_output_names;
140140
std::string m_op_name;
141141
mutable std::string m_name;
142-
int m_op_case;
142+
int m_op_case = 0;
143143
std::vector<std::pair<std::string, std::string>> m_op_node_name;
144144
std::map<std::string, std::shared_ptr<ov::Node>> m_model_inputs;
145145
std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs;

ggml/src/ggml-openvino/openvino/op/mulmat.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,8 @@ OutputVector translate_mulmat(const NodeContext& context) {
2929
ov::Output<Node> res;
3030
ov::Output<ov::Node> B = context.get_input(0);
3131
ov::Output<ov::Node> A = context.get_input(1);
32-
if (context.get_op_case() == 1) {
33-
if (context.get_input_type(0) == ov::element::f16) {
34-
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), ov::element::f32);
35-
}
36-
if (context.get_input_type(1) == ov::element::f16) {
37-
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), ov::element::f32);
38-
}
39-
} else {
40-
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
32+
if (context.get_input_type(0) != context.get_input_type(1)) {
33+
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
4134
}
4235

4336
auto B_shape = context.get_input_shape(0).to_shape();
@@ -72,8 +65,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
7265
A = Z;
7366
}
7467

75-
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
76-
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
68+
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
7769

7870
return rename_outputs_with_suffix({res}, context.get_name());
7971
}

ggml/src/ggml-openvino/openvino/op/soft_max.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,8 @@ OutputVector translate_soft_max(const NodeContext& context) {
5757
// Try using Q-cur to retrieve the token length, so that the translation of SOFT_MAX
5858
// does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
5959
// can be fused into SDPA.
60-
if (input_node->get_type_info() == ov::op::v0::Convert::get_type_info_static()) {
61-
auto qk = input_node->get_input_node_shared_ptr(0);
62-
if (qk->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
63-
token_len = get_dimensions(qk->get_input_node_shared_ptr(0), {1});
64-
}
60+
if (input_node->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
61+
token_len = get_dimensions(input_node->get_input_node_shared_ptr(0), {1});
6562
}
6663
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
6764
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});

ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <openvino/op/softmax.hpp>
1111
#include <openvino/op/transpose.hpp>
1212
#include <openvino/pass/pattern/op/label.hpp>
13+
#include <openvino/pass/pattern/op/optional.hpp>
1314
#include <openvino/pass/pattern/op/pattern.hpp>
1415
#include <openvino/pass/pattern/op/wrap_type.hpp>
1516

@@ -22,15 +23,13 @@ FuseToSDPA::FuseToSDPA() {
2223
const auto m_k = ov::pass::pattern::any_input();
2324
const auto m_q = ov::pass::pattern::any_input();
2425
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
25-
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
2626
const auto m_scale = ov::pass::pattern::any_input();
27-
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
27+
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk, m_scale});
2828
const auto m_mask = ov::pass::pattern::any_input();
2929
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
3030
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
31-
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
3231
const auto m_v = ov::pass::pattern::any_input();
33-
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
32+
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk, m_v});
3433

3534
const auto callback = [=](ov::pass::pattern::Matcher& m) {
3635
auto& pattern_to_output = m.get_pattern_value_map();
@@ -42,9 +41,7 @@ FuseToSDPA::FuseToSDPA() {
4241

4342
auto v_trans =
4443
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
45-
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
46-
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
47-
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
44+
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false);
4845

4946
ov::replace_node(m.get_match_root(), sdpa);
5047
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);

ggml/src/ggml-openvino/openvino/translate_session.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <openvino/op/unsqueeze.hpp>
2323
#include <openvino/pass/constant_folding.hpp>
2424
#include <openvino/pass/make_stateful.hpp>
25-
#include <transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp>
25+
#include <openvino/core/preprocess/pre_post_process.hpp>
2626

2727
#include "ggml-openvino/openvino/node_context.hpp"
2828
#include "ggml-openvino/openvino/utils.hpp"
@@ -254,22 +254,25 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
254254
return resulting_model;
255255
}
256256

257-
void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model) {
257+
std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<Model> model) {
258258
auto ggml_model_decoder = std::dynamic_pointer_cast<InputModel>(m_input_model)->get_model_decoder();
259+
{
260+
ov::pass::Manager manager;
261+
manager.set_per_pass_validation(true);
262+
263+
if (!ggml_model_decoder->is_static()) {
264+
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
265+
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
266+
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
267+
}
259268

260-
ov::pass::Manager manager;
261-
manager.set_per_pass_validation(true);
262-
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
263-
manager.register_pass<ov::pass::ConstantFolding>();
264-
265-
if (!ggml_model_decoder->is_static()) {
266-
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
267-
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
268-
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
269+
// FuseToSDPA is disabled for now: fusing to SDPA made performance even worse.
270+
// manager.register_pass<pass::FuseToSDPA>();
271+
manager.run_passes(model);
269272
}
270-
271-
manager.register_pass<pass::FuseToSDPA>();
272-
manager.run_passes(model);
273+
auto preprocessor = ov::preprocess::PrePostProcessor(model);
274+
model = preprocessor.build();
275+
return model;
273276
}
274277

275278
} // namespace ggml

ggml/src/ggml-openvino/openvino/translate_session.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TranslateSession {
1616
std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model);
1717

1818
private:
19-
void apply_transformations(const std::shared_ptr<Model>& model);
19+
std::shared_ptr<Model> apply_transformations(std::shared_ptr<Model> model);
2020
const frontend::InputModel::Ptr m_input_model;
2121
const std::unordered_map<std::string, CreatorFunction>& m_translator_map;
2222
std::shared_ptr<Model> m_ov_model;

0 commit comments

Comments
 (0)