Commit fcc7d16 (1 parent: 20e203e)

Fix test-thread-safety

2 files changed: +81 -77 lines

ggml/src/ggml-openvino/utils.cpp (80 additions, 74 deletions)
@@ -7,6 +7,7 @@
 #include <cstdint>
 #include <cstdlib>
 #include <memory>
+#include <mutex>
 #include <openvino/core/any.hpp>
 #include <openvino/core/graph_util.hpp>
 #include <openvino/core/type/float16.hpp>
@@ -96,6 +97,7 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
         core.set_property(ov::cache_dir(cache_dir));
     }

+    static std::mutex cache_mutex;
     static std::unordered_map<struct ggml_cgraph*, std::shared_ptr<ov::InferRequest>> infer_request_cache;
     static std::unordered_map<struct ggml_cgraph*, std::vector<std::string>> ov_input_names_cache;
     static std::unordered_map<struct ggml_cgraph*, std::vector<std::string>> ov_output_names_cache;
@@ -109,89 +111,93 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
     int64_t conversion_end_time;
     int64_t compile_end_time;

-    auto it = infer_request_cache.find(cgraph);
-    if (it != infer_request_cache.end()) {
-        std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
-        ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, false);
-        decoder_end_time = ggml_time_us();
-
-        // For NPU for the first time we call kvcache modle, pop the compiled kvcache model from cache
-        if (is_static && compiled_model_cache.find(cgraph) != compiled_model_cache.end()) {
-            infer_request_cache[cgraph] =
-                std::make_shared<ov::InferRequest>(compiled_model_cache[cgraph].create_infer_request());
-            compiled_model_cache.erase(cgraph);
-        }
-        infer_request = *infer_request_cache[cgraph];
-
-        conversion_end_time = ggml_time_us();
-        compile_end_time = conversion_end_time;
-    } else {
-        std::shared_ptr<ov::Model> model;
-        auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
+    {
+        std::lock_guard<std::mutex> lock(cache_mutex);

-        if (is_static) {
-            ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, true);
-            auto ggml_decoder_kvcache = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, false);
+        auto it = infer_request_cache.find(cgraph);
+        if (it != infer_request_cache.end()) {
+            std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
+            ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, false);
             decoder_end_time = ggml_time_us();

-            auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
-            auto input_model_kvcache = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_kvcache);
-
-            model = ov::frontend::ggml::FrontEnd::convert(input_model);
-            ggml_decoder->clear_model_weights();
-            auto model_kvcache = ov::frontend::ggml::FrontEnd::convert(input_model_kvcache);
-            ggml_decoder_kvcache->clear_model_weights();
-            conversion_end_time = ggml_time_us();
-
-            auto compiled_model = core.compile_model(model, device, config);
-            auto compiled_model_kvcache = core.compile_model(model_kvcache, device, config);
-            compiled_model_cache[cgraph] = compiled_model_kvcache;
-            compile_end_time = ggml_time_us();
-
-            infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
-            infer_request = *infer_request_cache[cgraph];
-            compiled_model_cache[cgraph] = compiled_model_kvcache;
-
-            if (getenv("GGML_OPENVINO_DUMP_IR")) {
-                char timestamped_filename[64];
-                auto timestamp = (long long) ggml_time_us();
-                snprintf(timestamped_filename, sizeof(timestamped_filename), "model_prefill_%lld.xml", timestamp);
-                ov::serialize(model, timestamped_filename);
-                snprintf(timestamped_filename, sizeof(timestamped_filename), "model_kvcache_%lld.xml", timestamp);
-                ov::serialize(model_kvcache, timestamped_filename);
+            // For NPU for the first time we call kvcache modle, pop the compiled kvcache model from cache
+            if (is_static && compiled_model_cache.find(cgraph) != compiled_model_cache.end()) {
+                infer_request_cache[cgraph] =
+                    std::make_shared<ov::InferRequest>(compiled_model_cache[cgraph].create_infer_request());
+                compiled_model_cache.erase(cgraph);
             }
-        } else {
-            ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, true);
-            decoder_end_time = ggml_time_us();
-
-            auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
-            model = ov::frontend::ggml::FrontEnd::convert(input_model);
-            ggml_decoder->clear_model_weights();
-            conversion_end_time = ggml_time_us();
-
-            auto compiled_model = core.compile_model(model, device, config);
-            compile_end_time = ggml_time_us();
-            infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
             infer_request = *infer_request_cache[cgraph];

-            if (getenv("GGML_OPENVINO_DUMP_IR")) {
-                char timestamped_filename[64];
-                auto timestamp = (long long) ggml_time_us();
-                snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp);
-                ov::serialize(model, timestamped_filename);
+            conversion_end_time = ggml_time_us();
+            compile_end_time = conversion_end_time;
+        } else {
+            std::shared_ptr<ov::Model> model;
+            auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
+
+            if (is_static) {
+                ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, true);
+                auto ggml_decoder_kvcache = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, false);
+                decoder_end_time = ggml_time_us();
+
+                auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
+                auto input_model_kvcache = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_kvcache);
+
+                model = ov::frontend::ggml::FrontEnd::convert(input_model);
+                ggml_decoder->clear_model_weights();
+                auto model_kvcache = ov::frontend::ggml::FrontEnd::convert(input_model_kvcache);
+                ggml_decoder_kvcache->clear_model_weights();
+                conversion_end_time = ggml_time_us();
+
+                auto compiled_model = core.compile_model(model, device, config);
+                auto compiled_model_kvcache = core.compile_model(model_kvcache, device, config);
+                compiled_model_cache[cgraph] = compiled_model_kvcache;
+                compile_end_time = ggml_time_us();
+
+                infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
+                infer_request = *infer_request_cache[cgraph];
+                compiled_model_cache[cgraph] = compiled_model_kvcache;
+
+                if (getenv("GGML_OPENVINO_DUMP_IR")) {
+                    char timestamped_filename[64];
+                    auto timestamp = (long long) ggml_time_us();
+                    snprintf(timestamped_filename, sizeof(timestamped_filename), "model_prefill_%lld.xml", timestamp);
+                    ov::serialize(model, timestamped_filename);
+                    snprintf(timestamped_filename, sizeof(timestamped_filename), "model_kvcache_%lld.xml", timestamp);
+                    ov::serialize(model_kvcache, timestamped_filename);
+                }
+            } else {
+                ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights, is_static, true);
+                decoder_end_time = ggml_time_us();
+
+                auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
+                model = ov::frontend::ggml::FrontEnd::convert(input_model);
+                ggml_decoder->clear_model_weights();
+                conversion_end_time = ggml_time_us();
+
+                auto compiled_model = core.compile_model(model, device, config);
+                compile_end_time = ggml_time_us();
+                infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
+                infer_request = *infer_request_cache[cgraph];
+
+                if (getenv("GGML_OPENVINO_DUMP_IR")) {
+                    char timestamped_filename[64];
+                    auto timestamp = (long long) ggml_time_us();
+                    snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp);
+                    ov::serialize(model, timestamped_filename);
+                }
             }
-        }

-        std::vector<std::string> ov_input_names;
-        std::vector<std::string> ov_output_names;
-        for (const auto& ov_param : model->get_parameters()) {
-            ov_input_names.push_back(ov_param->get_friendly_name());
-        }
-        for (const auto& ov_output : model->get_results()) {
-            ov_output_names.push_back(ov_output->get_friendly_name());
+            std::vector<std::string> ov_input_names;
+            std::vector<std::string> ov_output_names;
+            for (const auto& ov_param : model->get_parameters()) {
+                ov_input_names.push_back(ov_param->get_friendly_name());
+            }
+            for (const auto& ov_output : model->get_results()) {
+                ov_output_names.push_back(ov_output->get_friendly_name());
+            }
+            ov_input_names_cache[cgraph] = ov_input_names;
+            ov_output_names_cache[cgraph] = ov_output_names;
         }
-        ov_input_names_cache[cgraph] = ov_input_names;
-        ov_output_names_cache[cgraph] = ov_output_names;
     }

     auto ov_input_names = ov_input_names_cache[cgraph];
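
The change is a coarse-grained locking pattern: a function-local static std::mutex is taken with std::lock_guard before any of the static per-graph caches are read or written, so concurrent calls into openvino_frontend_compute can no longer race on the unordered_maps or compile the same cgraph twice. The sketch below shows the same pattern in isolation; it is an illustration only, and CompiledModel, expensive_compile, and get_or_compile are placeholder names, not the real ggml/OpenVINO types or functions.

// Minimal sketch of the mutex-guarded cache pattern introduced above.
// CompiledModel, expensive_compile, and get_or_compile are placeholders,
// not the actual ggml/OpenVINO APIs.
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

struct CompiledModel {
    std::string name;  // stand-in for ov::CompiledModel / ov::InferRequest
};

static CompiledModel expensive_compile(int graph_id) {
    // Stand-in for decoder construction, FrontEnd::convert, and compile_model.
    return CompiledModel{"model_" + std::to_string(graph_id)};
}

static CompiledModel get_or_compile(int graph_id) {
    // Function-local statics, mirroring cache_mutex / infer_request_cache in utils.cpp.
    static std::mutex cache_mutex;
    static std::unordered_map<int, CompiledModel> cache;

    // A single lock covers lookup, compilation, and insertion, so two threads
    // that miss on the same graph can neither compile it twice nor corrupt the map.
    std::lock_guard<std::mutex> lock(cache_mutex);
    auto it = cache.find(graph_id);
    if (it == cache.end()) {
        it = cache.emplace(graph_id, expensive_compile(graph_id)).first;
    }
    return it->second;
}

int main() {
    std::vector<std::thread> workers;
    for (int t = 0; t < 4; ++t) {       // roughly the parallel load the test applies
        workers.emplace_back([] {
            for (int g = 0; g < 3; ++g) {
                get_or_compile(g);      // safe to call from any thread
            }
        });
    }
    for (auto& w : workers) {
        w.join();
    }
    std::cout << "all threads done\n";
    return 0;
}

The trade-off is that the lock also serializes compilation of different graphs across threads, but for caches that are populated once per graph this is a simple and safe choice.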

tests/CMakeLists.txt (1 addition, 3 deletions)
@@ -185,9 +185,7 @@ llama_build_and_test(test-json-partial.cpp)
 llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-regex-partial.cpp)

-if (NOT GGML_OPENVINO)
-    llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
-endif()
+llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)

 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
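
With the GGML_OPENVINO guard removed, test-thread-safety is built for OpenVINO-enabled builds as well, exercising the cache_mutex added in utils.cpp under the concurrent load its ARGS describe (-np 4 parallel sequences, -t 2 threads). Assuming llama_build_and_test registers the binary with CTest, it can presumably be run from the build directory with a command along the lines of ctest -R test-thread-safety.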
