@@ -7,10 +7,10 @@ namespace anakin {
 
 template <typename Ttype, DataType Dtype, Precision Ptype, OpRunType RunType>
 Net<Ttype, Dtype, Ptype, RunType>::~Net() {
-    if (_graph_p) {
-        delete _graph_p;
-        _graph_p = nullptr;
-    }
+    if (_graph_p) {
+        delete _graph_p;
+        _graph_p = nullptr;
+    }
 }
 
 template <typename Ttype, DataType Dtype>
@@ -24,7 +24,7 @@ double tensor_average(Tensor4dPtr<Ttype, Dtype>& out_tensor_p) {
     tensorptr.h_tensor().copy_from(*out_tensor_p);
    hptr = tensorptr.h_tensor().data();
     for (int i=0; i<out_tensor_p->valid_size(); i++) {
-        sum += hptr[i];
+        sum += hptr[i];
     }
     return sum / out_tensor_p->valid_size();
 }
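For reference, tensor_average() reduces on the CPU after a device-to-host copy. A minimal standalone sketch of that pattern, with the framework types (Tensor4dPtr, h_tensor(), copy_from()) elided and host_copy standing in for the host-side buffer:

#include <numeric>
#include <vector>

// Host-side average of a tensor's valid region, mirroring tensor_average():
// copy device data into a host buffer first, then reduce on the CPU.
double average(const std::vector<float>& host_copy) {
    if (host_copy.empty()) return 0.0;  // guard the division below
    double sum = std::accumulate(host_copy.begin(), host_copy.end(), 0.0);
    return sum / static_cast<double>(host_copy.size());
}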
@@ -138,8 +138,8 @@ void Net<Ttype, Dtype, Ptype, RunType>::init(graph::Graph<Ttype, Dtype, Ptype>&
     init_env(graph);
     // shallow copy
     _graph_p->CopyFrom(graph);
-
-    double curr_mem_in_mb_start = MemoryInfo<Ttype>::Global().get_used_mem_in_mb();
+
+    double curr_mem_in_mb_start = MemoryInfo<Ttype>::Global().get_used_mem_in_mb();
 
     auto node_names_in_exec_order = graph.get_nodes_in_order();
     // infer basic shape and parsing parameter from graph
@@ -190,18 +190,24 @@ void Net<Ttype, Dtype, Ptype, RunType>::init(graph::Graph<Ttype, Dtype, Ptype>&
         if (node_ptr->get_op_name() == "ConvBatchnormScale" ||
             node_ptr->get_op_name() == "ConvBatchnormScaleRelu" || node_ptr->get_op_name() == "ConvRelu" ||
             node_ptr->get_op_name() == "Convolution") {
-            std::string group = "group";
+            std::string group = "group";
             auto group_val = node_ptr->template get_attr<int>(group);
+            std::string dilation = "dilation_rate";
+            auto dilation_rate_val = node_ptr->template get_attr<PTuple<int> >(dilation);
             using pblock_type = PBlock<typename DataTypeWarpper<Dtype>::type, Ttype>;
             std::string weight_name = "weight_1";
             auto weights = node_ptr->template get_attr<pblock_type>(weight_name);
-            // int c = weights.d_tensor().channel();
-
-            if ((group_val == 1)) {
-                node_ptr->set_op(OpFactory<Ttype, Dtype, Ptype>::Global()["Sass" + node_ptr->get_op_name()]);
-                node_ptr->get_op_name() = "Sass" + node_ptr->get_op_name();
-            } else {
-                LOG(ERROR) << "node_ptr->get_op_name() sass not support yet.";
+
+            int k_w = weights.d_tensor().width();
+            int k_h = weights.d_tensor().height();
+            int dil_h = dilation_rate_val.vector()[0];
+            int dil_w = dilation_rate_val.vector()[1];
+
+            if ((group_val == 1) && (k_w == 3 && k_h == 3 && dil_h == 1 && dil_w == 1)) {
+                node_ptr->set_op(OpFactory<Ttype, Dtype, Ptype>::Global()["Sass" + node_ptr->get_op_name()]);
+                node_ptr->get_op_name() = "Sass" + node_ptr->get_op_name();
+            } else {
+                LOG(ERROR) << node_ptr->get_op_name() << " sass not supported yet.";
                 auto *op_pointer = OpFactory<Ttype, Dtype, Ptype>::Global()[node_ptr->get_op_name()];
                 node_ptr->set_op(op_pointer);
             }
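The hunk above is the substance of this commit: the hand-written "Sass" (SASS assembly) convolution kernels are now selected only for shapes they actually support, instead of for every group==1 convolution. A condensed restatement of the predicate, using a hypothetical ConvDesc helper rather than the framework's node/attribute API:

// Sketch of the dispatch rule added above: SASS conv kernels only for
// non-grouped 3x3 convolutions with unit dilation. ConvDesc is not a
// framework type, just an illustration.
struct ConvDesc {
    int group;
    int k_w, k_h;      // kernel width/height read from the weight tensor
    int dil_h, dil_w;  // dilation_rate attribute, [0] = h, [1] = w
};

inline bool can_use_sass(const ConvDesc& d) {
    return d.group == 1
        && d.k_w == 3 && d.k_h == 3
        && d.dil_h == 1 && d.dil_w == 1;
}

Anything that fails the check keeps the generic factory op, so unsupported shapes fall back gracefully instead of reaching an unimplemented kernel.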
@@ -285,16 +291,16 @@ void Net<Ttype, Dtype, Ptype, RunType>::init(graph::Graph<Ttype, Dtype, Ptype>&
 #endif
     }
 
-    double curr_mem_in_mb_end = MemoryInfo<Ttype>::Global().get_used_mem_in_mb();
-    this->_graph_p->statistics.template set_info<graph::SYSTEM_MEM>(curr_mem_in_mb_end - curr_mem_in_mb_start);
+    double curr_mem_in_mb_end = MemoryInfo<Ttype>::Global().get_used_mem_in_mb();
+    this->_graph_p->statistics.template set_info<graph::SYSTEM_MEM>(curr_mem_in_mb_end - curr_mem_in_mb_start);
     // init memory of _graph_p
     init_memory();
-
-    graph.statistics = _graph_p->statistics; // copy statistics back
-    LOG(INFO) << "Temp mem used:     " << this->_graph_p->statistics.template get_info<graph::TEMP_MEM>() << " MB";
-    LOG(INFO) << "Original mem used: " << this->_graph_p->statistics.template get_info<graph::ORI_TEMP_MEM>() << " MB";
-    LOG(INFO) << "Model mem used:    " << this->_graph_p->statistics.template get_info<graph::MODEL_MEM>() << " MB";
-    LOG(INFO) << "System mem used:   " << this->_graph_p->statistics.template get_info<graph::SYSTEM_MEM>() << " MB";
+
+    graph.statistics = _graph_p->statistics; // copy statistics back
+    LOG(INFO) << "Temp mem used:     " << this->_graph_p->statistics.template get_info<graph::TEMP_MEM>() << " MB";
+    LOG(INFO) << "Original mem used: " << this->_graph_p->statistics.template get_info<graph::ORI_TEMP_MEM>() << " MB";
+    LOG(INFO) << "Model mem used:    " << this->_graph_p->statistics.template get_info<graph::MODEL_MEM>() << " MB";
+    LOG(INFO) << "System mem used:   " << this->_graph_p->statistics.template get_info<graph::SYSTEM_MEM>() << " MB";
 #ifdef ENABLE_OP_TIMER
     _op_time = std::vector<float>(_exec_funcs.size(), 0.0f);
 #endif
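The SYSTEM_MEM figure logged above is simply the delta of the allocator's used-memory counter sampled before and after init. A runnable sketch of that bookkeeping, where query_used_mem_mb() and build_network() are hypothetical stand-ins for MemoryInfo<Ttype>::Global().get_used_mem_in_mb() and the body of init():

#include <iostream>

// Stand-ins for the framework's memory counter and the init work; shown
// only to illustrate the before/after sampling above.
static double g_used_mb = 0.0;
double query_used_mem_mb() { return g_used_mb; }
void build_network() { g_used_mb += 42.0; }  // pretend init allocates 42 MB

int main() {
    double before_mb = query_used_mem_mb();
    build_network();
    double system_mb = query_used_mem_mb() - before_mb;  // the SYSTEM_MEM delta
    std::cout << "System mem used: " << system_mb << " MB\n";
    return 0;
}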
@@ -312,11 +318,11 @@ void Net<Ttype, Dtype, Ptype, RunType>::init(graph::Graph<Ttype, Dtype, Ptype>&
         LOG(WARNING) << "Inspect memory of " << executer.name << " (" << executer.op_name << ") ";
         executer.infer_shape();
 
-        for (auto out : executer.outs) {
-            LOG(INFO) << "|-- out tensor avg " << tensor_average(out);
-        }
+        for (auto out : executer.outs) {
+            LOG(INFO) << "|-- out tensor avg " << tensor_average(out);
+        }
 #ifdef USE_CUDA
-        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaDeviceSynchronize());
         CUDA_CHECK(cudaPeekAtLastError());
 #endif
     }
@@ -344,15 +350,15 @@ void Net<Ttype, Dtype, Ptype, RunType>::prediction() {
                 << " " << in->valid_shape()[1]
                 << " " << in->valid_shape()[2]
                 << " " << in->valid_shape()[3]
-                << " valid_size: " << in->valid_size()
-                << " realsize: " << in->size()
+                << " valid_size: " << in->valid_size()
+                << " realsize: " << in->size()
                 << " offset_size " << in->get_seq_offset().size();
         }
 #endif
 #ifdef ENABLE_OP_TIMER
-        Context<Ttype> ctx(0, 0, 0);
-        saber::SaberTimer<Ttype> my_time;
-        my_time.start(ctx);
+        Context<Ttype> ctx(0, 0, 0);
+        saber::SaberTimer<Ttype> my_time;
+        my_time.start(ctx);
 #endif
         if (executer.op_name != "Input") {
             executer.infer_shape();
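Under ENABLE_OP_TIMER each executor's infer_shape()/launch() is bracketed by a timer and the elapsed milliseconds are accumulated into _op_time, one slot per op. A sketch of the same bracketing with std::chrono standing in for saber::SaberTimer (assumed, not the framework API):

#include <chrono>

// Time one op invocation in milliseconds; run_op is any callable wrapping
// infer_shape() + launch() + the event sync done in the real loop.
template <typename F>
float time_op_ms(F&& run_op) {
    auto t0 = std::chrono::steady_clock::now();
    run_op();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
}
// e.g. _op_time[op_id++] += time_op_ms([&] { executer.launch(); });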
@@ -368,35 +374,35 @@ void Net<Ttype, Dtype, Ptype, RunType>::prediction() {
             executer.outs[i]->record_event(executer.ctx_p->get_compute_stream());
             executer.outs[i]->sync();
         }
-        my_time.end(ctx);
+        my_time.end(ctx);
         _op_time[op_id++] += my_time.get_average_ms();
 #endif
-        // LOG(INFO) << "op: " << executer.name << "(" << executer.op_name << ") === infer+launch time " << my_time.get_average_ms() << " ms";
+        // LOG(INFO) << "op: " << executer.name << "(" << executer.op_name << ") === infer+launch time " << my_time.get_average_ms() << " ms";
 #ifdef ENABLE_DEBUG
 #ifdef USE_CUDA
         CUDA_CHECK(cudaDeviceSynchronize());
         CUDA_CHECK(cudaPeekAtLastError());
 #endif
-        for (auto out : executer.outs) {
-            std::vector<int> offset = out->get_seq_offset();
-            LOG(INFO) << "print offset of " << executer.name << ", size = " << offset.size();
-            for (int i=0; i<offset.size(); ++i) {
-                LOG(INFO) << offset[i] << ",";
-            }
-            LOG(INFO) << "end print offset of " << executer.name;
+        for (auto out : executer.outs) {
+            std::vector<int> offset = out->get_seq_offset();
+            LOG(INFO) << "print offset of " << executer.name << ", size = " << offset.size();
+            for (int i=0; i<offset.size(); ++i) {
+                LOG(INFO) << offset[i] << ",";
+            }
+            LOG(INFO) << "end print offset of " << executer.name;
 #define RECORD_INNER
 #if defined(RECORD_INNER) && defined(USE_X86_PLACE)
-            record_tensor_to_file(*out, ("record_" + executer.name).c_str());
-            if (executer.name == "")
+            record_tensor_to_file(*out, ("record_" + executer.name).c_str());
+            if (executer.name == "")
 #endif
             LOG(INFO) << executer.name << " d_tensor_out_p :" << out->data();
 #ifdef USE_X86_PLACE
 //            for (int i = 0; i < 10; ++i) {
 //                std::cout << out->data()[i] << " ";
 //            }
 #endif
-            LOG(ERROR) << "|---out avg " << tensor_average(out);
-        }
+            LOG(ERROR) << "|---out avg " << tensor_average(out);
+        }
 
 #ifdef USE_ARM_PLACE
         int idx = 0;
@@ -468,15 +474,15 @@ void Net<Ttype, Dtype, Ptype, RunType>::prediction() {
 
 template <typename Ttype, DataType Dtype, Precision Ptype, OpRunType RunType>
 void Net<Ttype, Dtype, Ptype, RunType>::execute_stop_at_node(std::string node_name) {
-    if (_suspended_point == -1) {
-        for (int i=0; i<_exec_funcs.size(); i++) {
-            if (_exec_funcs[i].name == node_name) {
-                _suspended_point = i;
-            }
-        }
-    }
-    for (int i=0; i<_suspended_point; i++) {
-        auto& executer = _exec_funcs[i];
+    if (_suspended_point == -1) {
+        for (int i=0; i<_exec_funcs.size(); i++) {
+            if (_exec_funcs[i].name == node_name) {
+                _suspended_point = i;
+            }
+        }
+    }
+    for (int i=0; i<_suspended_point; i++) {
+        auto& executer = _exec_funcs[i];
     if (RunType == OpRunType::SYNC || executer.need_sync) {
         for (int i = 0; i < executer.ins.size(); i++) {
             // record
@@ -491,37 +497,37 @@ void Net<Ttype, Dtype, Ptype, RunType>::execute_stop_at_node(std::string node_name)
                 << " " << in->valid_shape()[1]
                 << " " << in->valid_shape()[2]
                 << " " << in->valid_shape()[3]
-                << " valid_size: " << in->valid_size()
-                << " realsize: " << in->size()
-                << " offset_size " << in->get_seq_offset().size();
+                << " valid_size: " << in->valid_size()
+                << " realsize: " << in->size()
+                << " offset_size " << in->get_seq_offset().size();
+        }
+        for (auto out : executer.outs) {
+            LOG(INFO) << "|-- out tensor avg " << tensor_average(out);
         }
-        for (auto out : executer.outs) {
-            LOG(INFO) << "|-- out tensor avg " << tensor_average(out);
-        }
 
 #endif
-        if (executer.op_name != "Input") {
-            executer.infer_shape();
-            executer.launch();
-        }
+        if (executer.op_name != "Input") {
+            executer.infer_shape();
+            executer.launch();
+        }
 
-        for (int i = 0; i < executer.outs.size(); i++) {
-            executer.outs[i]->record_event(executer.ctx_p->get_compute_stream());
-        }
-    }
+        for (int i = 0; i < executer.outs.size(); i++) {
+            executer.outs[i]->record_event(executer.ctx_p->get_compute_stream());
+        }
+    }
 }
 
 template <typename Ttype, DataType Dtype, Precision Ptype, OpRunType RunType>
 void Net<Ttype, Dtype, Ptype, RunType>::execute_start_from_node(std::string node_name) {
-    if (_start_point == -1) {
-        for (int i=0; i<_exec_funcs.size(); i++) {
-            if (_exec_funcs[i].name == node_name) {
-                _start_point = i;
-            }
-        }
-    }
-    for (int i=_start_point; i<_exec_funcs.size(); i++) {
-        auto& executer = _exec_funcs[i];
+    if (_start_point == -1) {
+        for (int i=0; i<_exec_funcs.size(); i++) {
+            if (_exec_funcs[i].name == node_name) {
+                _start_point = i;
+            }
+        }
+    }
+    for (int i=_start_point; i<_exec_funcs.size(); i++) {
+        auto& executer = _exec_funcs[i];
     if (RunType == OpRunType::SYNC || executer.need_sync) {
         for (int i = 0; i < executer.ins.size(); i++) {
             // record
@@ -536,24 +542,24 @@ void Net<Ttype, Dtype, Ptype, RunType>::execute_start_from_node(std::string node_name)
                 << " " << in->valid_shape()[1]
                 << " " << in->valid_shape()[2]
                 << " " << in->valid_shape()[3]
-                << " valid_size: " << in->valid_size()
-                << " realsize: " << in->size()
-                << " offset_size " << in->get_seq_offset().size();
+                << " valid_size: " << in->valid_size()
+                << " realsize: " << in->size()
+                << " offset_size " << in->get_seq_offset().size();
+        }
+        for (auto out : executer.outs) {
+            LOG(INFO) << "|-- out tensor avg " << tensor_average(out);
        }
-        for (auto out : executer.outs) {
-            LOG(INFO) << "|-- out tensor avg " << tensor_average(out);
-        }
 
 #endif
-        if (executer.op_name != "Input") {
-            executer.infer_shape();
-            executer.launch();
-        }
+        if (executer.op_name != "Input") {
+            executer.infer_shape();
+            executer.launch();
+        }
 
-        for (int i = 0; i < executer.outs.size(); i++) {
-            executer.outs[i]->record_event(executer.ctx_p->get_compute_stream());
-        }
-    }
+        for (int i = 0; i < executer.outs.size(); i++) {
+            executer.outs[i]->record_event(executer.ctx_p->get_compute_stream());
+        }
+    }
 }
 
 template <typename Ttype, DataType Dtype, Precision Ptype, OpRunType RunType>
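execute_stop_at_node() and execute_start_from_node() above split one forward pass at a named node: the first caches the node's index in _suspended_point and runs ops in [0, index), the second caches _start_point and runs [index, end), so the named node itself belongs to the resumed half. A self-contained model of that index resolution and the two half-loops (node names illustrative):

#include <iostream>
#include <string>
#include <vector>

// Resolve a node name to its position in execution order; the real code
// caches the result in _suspended_point / _start_point.
int find_node(const std::vector<std::string>& order, const std::string& name) {
    for (int i = 0; i < static_cast<int>(order.size()); i++) {
        if (order[i] == name) return i;
    }
    return -1;
}

int main() {
    std::vector<std::string> order = {"Input", "conv1", "relu1", "conv2"};
    int point = find_node(order, "conv2");
    for (int i = 0; i < point; i++)                               // stop-at half: [0, point)
        std::cout << "run " << order[i] << "\n";
    for (int i = point; i < static_cast<int>(order.size()); i++)  // start-from half: [point, end)
        std::cout << "resume " << order[i] << "\n";
    return 0;
}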
@@ -607,27 +613,27 @@ Status Net<Ttype, Dtype, Ptype, RunType>::init_memory() {
     auto share_memory = [this](graph::Edge<Ttype, Dtype>& edge) {
         if (edge.shared()) {
             auto& edge_name = edge.share_from();
-            bool continue_search = true;
-            while (continue_search) {
-                auto match_edge = [&](graph::Edge<Ttype, Dtype>& inner_edge) {
-                    if (inner_edge.name() == edge_name) {
-                        if (inner_edge.shared()) {
-                            edge_name = inner_edge.share_from();
-                            return Status::EXIT("Continue to find next.");
-                        }
-                        if (inner_edge.weight()->size() < edge.weight()->valid_size()) {
-                            auto inner_original_shape = inner_edge.weight()->valid_shape();
-                            inner_edge.weight()->re_alloc(edge.weight()->valid_shape());
-                            inner_edge.weight()->set_shape(inner_original_shape, inner_edge.weight()->shape());
-                        }
-                        edge.weight()->share_from(*(inner_edge.weight()));
-                        continue_search = false;
-                        return Status::EXIT("Find the matched target edge.");
-                    }
-                    return Status::OK();
-                };
-                this->_graph_p->Scanner->BFS_Edge(match_edge);
-            }
+            bool continue_search = true;
+            while (continue_search) {
+                auto match_edge = [&](graph::Edge<Ttype, Dtype>& inner_edge) {
+                    if (inner_edge.name() == edge_name) {
+                        if (inner_edge.shared()) {
+                            edge_name = inner_edge.share_from();
+                            return Status::EXIT("Continue to find next.");
+                        }
+                        if (inner_edge.weight()->size() < edge.weight()->valid_size()) {
+                            auto inner_original_shape = inner_edge.weight()->valid_shape();
+                            inner_edge.weight()->re_alloc(edge.weight()->valid_shape());
+                            inner_edge.weight()->set_shape(inner_original_shape, inner_edge.weight()->shape());
+                        }
+                        edge.weight()->share_from(*(inner_edge.weight()));
+                        continue_search = false;
+                        return Status::EXIT("Find the matched target edge.");
+                    }
+                    return Status::OK();
+                };
+                this->_graph_p->Scanner->BFS_Edge(match_edge);
+            }
         }
     };
     _graph_p->Scanner->BFS_Edge(share_memory);
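The share_memory pass above resolves chains of shared edges: an edge that borrows its buffer names the edge it borrows from, which may itself be shared, so the loop keeps re-scanning until it reaches an edge that owns its buffer (growing that buffer first if it is too small). A map-based sketch of just the chain-following step, with EdgeInfo as a hypothetical stand-in for graph::Edge:

#include <string>
#include <unordered_map>

// Follow share_from() links until an edge that owns its buffer. The real
// code does this with repeated BFS_Edge passes over the graph; the map
// here is only a sketch of the chain-following idea.
struct EdgeInfo {
    bool shared = false;
    std::string share_from;  // meaningful only when shared == true
};

std::string resolve_root(const std::unordered_map<std::string, EdgeInfo>& edges,
                         std::string name) {
    while (edges.at(name).shared) {
        name = edges.at(name).share_from;  // hop to the edge we borrow from
    }
    return name;  // the edge whose buffer the whole chain ends up sharing
}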
@@ -644,8 +650,8 @@ Status Net<Ttype, Dtype, Ptype, RunType>::init_memory() {
         };
         this->_graph_p->Scanner->BFS_Edge(analysis_used_of_temp_mem);
 
-        this->_graph_p->statistics.template set_info<graph::TEMP_MEM>(temp_mem_in_mbytes / 1e6);
-        this->_graph_p->statistics.template set_info<graph::ORI_TEMP_MEM>(ori_temp_mem_in_mbytes / 1e6);
+        this->_graph_p->statistics.template set_info<graph::TEMP_MEM>(temp_mem_in_mbytes / 1e6);
+        this->_graph_p->statistics.template set_info<graph::ORI_TEMP_MEM>(ori_temp_mem_in_mbytes / 1e6);
     }
     return Status::OK();
 }
@@ -700,4 +706,3 @@ template class Net<ARM, AK_FLOAT, Precision::INT8, OpRunType::SYNC>;
 #endif // arm
 
 } /* namespace anakin */
-