Skip to content

Commit 7ac93d3

Browse files
committed
temp. changes for mark decomp
1 parent a8fa0e5 commit 7ac93d3

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

ggml/src/ggml-openvino/openvino/op/mulmat.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,17 @@ OutputVector translate_mulmat(const NodeContext& context) {
2828

2929
ov::Output<Node> res;
3030
ov::Output<ov::Node> B = context.get_input(0);
31-
ov::Output<ov::Node> A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
31+
ov::Output<ov::Node> A = context.get_input(1);
32+
if (context.get_op_case() == 1) {
33+
if (context.get_input_type(0) == ov::element::f16) {
34+
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), ov::element::f32);
35+
}
36+
if (context.get_input_type(1) == ov::element::f16) {
37+
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), ov::element::f32);
38+
}
39+
} else {
40+
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
41+
}
3242

3343
auto B_shape = context.get_input_shape(0).to_shape();
3444
auto A_shape = context.get_input_shape(1).to_shape();

ggml/src/ggml-openvino/openvino/translate_session.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <openvino/op/unsqueeze.hpp>
2323
#include <openvino/pass/constant_folding.hpp>
2424
#include <openvino/pass/make_stateful.hpp>
25+
#include <transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp>
2526

2627
#include "ggml-openvino/openvino/node_context.hpp"
2728
#include "ggml-openvino/openvino/utils.hpp"
@@ -258,6 +259,7 @@ void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model
258259

259260
ov::pass::Manager manager;
260261
manager.set_per_pass_validation(true);
262+
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
261263
manager.register_pass<ov::pass::ConstantFolding>();
262264

263265
if (!ggml_model_decoder->is_static()) {

0 commit comments

Comments
 (0)