Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions test/saber/test_saber_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class TestSaberBase{
for(int input_index = 0; input_index < _inputs_dev.size(); ++input_index){
_base_op.init(_inputs_dev[input_index], _outputs_dev[input_index],
_params[param_index], strategy, implenum, ctx);
for(int iter=0; iter<100; ++iter){
for(int iter=0; iter<_gpu_iters; ++iter){
_outputs_dev[input_index][0]->copy_from(*_outputs_host[input_index][0]);
status= _base_op(_inputs_dev[input_index], _outputs_dev[input_index],
_params[param_index], ctx);
Expand Down Expand Up @@ -325,14 +325,14 @@ class TestSaberBase{

std :: vector<std :: string> runtype{"STATIC", "RUNTIME", "SPECIFY"};
std :: vector<std :: string> impltype{"VENDER", "SABER"};
get_cpu_result(CpuFunc);//first get cpu result
for(auto strate : {SPECIFY, RUNTIME, STATIC}){
for(auto implenum : {VENDER_IMPL, SABER_IMPL}){
LOG(INFO) << "TESTING: strategy:" << runtype[strate-1] << ",impltype:" << impltype[(int)implenum];
if(get_op_result(strate, implenum) == SaberUnImplError){
LOG(INFO) << "Unimpl!!";
continue;
}
get_cpu_result(CpuFunc);
result_check_accuracy(succ_ratio);
}
}
Expand All @@ -342,6 +342,9 @@ class TestSaberBase{
void set_random_output(bool random_output) {
_use_random_output = random_output;
}
void set_gpu_iters(int iters){
_gpu_iters = iters;
}
private:
int _op_input_num;
int _op_output_num;
Expand All @@ -358,6 +361,7 @@ class TestSaberBase{
std :: vector<std::vector<Shape>> _input_shapes;
std :: vector<Param_t> _params;
bool _use_random_output{false};
int _gpu_iters{1};
};//testsaberbase
}//namespace saber
}//namespace anakin
Expand Down