Skip to content

Commit ff401b8

Browse files
cavusmustafawine99
authored andcommitted
add mark decomp pass
1 parent 2e1c54e commit ff401b8

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include "mark_decompression_convert_constant_folding.hpp"
4+
#include "openvino/pass/matcher_pass.hpp"
5+
#include "openvino/core/visibility.hpp"
6+
7+
#ifdef OPENVINO_STATIC_LIBRARY
8+
# define TRANSFORMATIONS_API
9+
#else
10+
# ifdef IMPLEMENT_OPENVINO_API
11+
# define TRANSFORMATIONS_API OPENVINO_CORE_EXPORTS
12+
# else
13+
# define TRANSFORMATIONS_API OPENVINO_CORE_IMPORTS
14+
# endif // IMPLEMENT_OPENVINO_API
15+
#endif // OPENVINO_STATIC_LIBRARY
16+
17+
namespace ov {
18+
namespace pass {
19+
20+
class TRANSFORMATIONS_API MarkCompressedFloatConstants;
21+
22+
} // namespace pass
23+
} // namespace ov
24+
25+
class ov::pass::MarkCompressedFloatConstants : public MatcherPass {
26+
public:
27+
OPENVINO_MATCHER_PASS_RTTI("MarkCompressedFloatConstants");
28+
MarkCompressedFloatConstants();
29+
};

ggml/src/ggml-openvino/openvino/translate_session.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "ggml-openvino/openvino/utils.hpp"
2929
#include "input_model.hpp"
3030
#include "pass/fuse_to_sdpa.hpp"
31+
#include "pass/mark_decompression_convert_constant_folding.hpp"
3132

3233
namespace ov {
3334
namespace frontend {
@@ -259,6 +260,8 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
259260
{
260261
ov::pass::Manager manager;
261262
manager.set_per_pass_validation(true);
263+
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
264+
manager.register_pass<ov::pass::ConstantFolding>();
262265

263266
if (!ggml_model_decoder->is_static()) {
264267
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
@@ -267,7 +270,7 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
267270
}
268271

269272
// SDPA is even worse on performance
270-
// manager.register_pass<pass::FuseToSDPA>();
273+
manager.register_pass<pass::FuseToSDPA>();
271274
manager.run_passes(model);
272275
}
273276
auto preprocessor = ov::preprocess::PrePostProcessor(model);

0 commit comments

Comments
 (0)