@@ -168,7 +168,7 @@ enum direction{diag, move_left, up};
168168
169169std::vector<std::tuple<int32, int32> > get_best_path (py::array_t <double > array,
170170 const StringVector& words_a,
171- const StringVector& words_b, const bool use_chardiff) {
171+ const StringVector& words_b, const bool use_chardiff, const bool use_fast_edit_distance= true ) {
172172 auto buf = array.request ();
173173 double * cost_mat = (double *) buf.ptr ;
174174 int32_t numr = array.shape ()[0 ], numc = array.shape ()[1 ];
@@ -202,8 +202,89 @@ std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
202202 if (alen >= 50 || blen >= 50 ) {
203203 throw std::runtime_error (" Word is too long! Increase buffer" );
204204 }
205- diag_trans_cost =
205+ if (use_fast_edit_distance) {
206+ diag_trans_cost =
206207 calc_edit_distance_fast (char_dist_buffer.data (), a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
208+ } else {
209+ diag_trans_cost =
210+ levdistance (a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
211+ }
212+ } else {
213+ diag_trans_cost = a == b ? 0 . : 1 .;
214+ }
215+
216+ if (isclose (diagc + diag_trans_cost, current_cost)) {
217+ direc = diag;
218+ } else if (isclose (upc + up_trans_cost, current_cost)) {
219+ direc = up;
220+ } else if (isclose (leftc + left_trans_cost, current_cost)) {
221+ direc = move_left;
222+ } else {
223+ std::cout << a <<" " <<b<<" " <<i<<" " <<j<<" trans " <<diag_trans_cost<<" " <<left_trans_cost<<" " <<up_trans_cost<<" costs " <<current_cost<<" " <<diagc<<" " <<leftc<<" " <<upc <<std::endl;
224+ std::cout << (diag_trans_cost + diagc == current_cost) <<std::endl;
225+ std::cout << diag_trans_cost + diagc <<" " <<current_cost <<std::endl;
226+ throw std::runtime_error (" Should not be possible !" );
227+ }
228+ }
229+
230+ if (direc == up) {
231+ i--;
232+ bestpath.emplace_back (i, -1 ); // -1 means null token
233+ } else if (direc == move_left) {
234+ j--;
235+ bestpath.emplace_back (-1 , j);
236+ } else if (direc == diag) {
237+ i--, j--;
238+ bestpath.emplace_back (i, j);
239+ }
240+ }
241+ return bestpath;
242+ }
243+
244+
245+ std::vector<std::tuple<int32, int32> > get_best_path_lists (py::array_t <double > array,
246+ const std::vector<std::string>& words_a,
247+ const std::vector<std::string>& words_b, const bool use_chardiff, const bool use_fast_edit_distance=true ) {
248+ auto buf = array.request ();
249+ double * cost_mat = (double *) buf.ptr ;
250+ int32_t numr = array.shape ()[0 ], numc = array.shape ()[1 ];
251+ std::vector<int32> char_dist_buffer;
252+ if (use_chardiff) {
253+ char_dist_buffer.resize (100 );
254+ }
255+
256+ std::vector<std::tuple<int , int > > bestpath;
257+ int i = numr - 1 , j = numc - 1 ;
258+ while (i != 0 || j != 0 ) {
259+ double upc, leftc, diagc;
260+ direction direc;
261+ if (i == 0 ) {
262+ direc = move_left;
263+ } else if (j == 0 ) {
264+ direc = up;
265+ } else {
266+ float current_cost = cost_mat[i * numc + j];
267+ upc = cost_mat[(i-1 ) * numc + j];
268+ leftc = cost_mat[i * numc + j - 1 ];
269+ diagc = cost_mat[(i-1 ) * numc + j - 1 ];
270+ const std::string& a = words_a[i-1 ];
271+ const std::string& b = words_b[j-1 ];
272+ double up_trans_cost = 1.0 ;
273+ double left_trans_cost = 1.0 ;
274+ double diag_trans_cost;
275+ if (use_chardiff) {
276+ int alen = a.size ();
277+ int blen = b.size ();
278+ if (alen >= 50 || blen >= 50 ) {
279+ throw std::runtime_error (" Word is too long! Increase buffer" );
280+ }
281+ if (use_fast_edit_distance) {
282+ diag_trans_cost =
283+ calc_edit_distance_fast (char_dist_buffer.data (), a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
284+ } else {
285+ diag_trans_cost =
286+ levdistance (a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
287+ }
207288 } else {
208289 diag_trans_cost = a == b ? 0 . : 1 .;
209290 }
@@ -322,12 +403,12 @@ void get_best_path_ctm(py::array_t<double> array, py::list& bestpath_lst, std::v
322403
323404
324405int calc_sum_cost (py::array_t <double > array, const StringVector& words_a,
325- const StringVector& words_b, const bool use_chardist) {
406+ const StringVector& words_b, const bool use_chardist, const bool use_fast_edit_distance= true ) {
326407 if ( array.ndim () != 2 )
327408 throw std::runtime_error (" Input should be 2-D NumPy array" );
328409
329410 int M1 = array.shape ()[0 ], N1 = array.shape ()[1 ];
330- if (M1 - 1 != words_a.Size () || N1 - 1 != words_b.Size ()) throw std::runtime_error (" Sizes do not match!" );
411+ if (M1 - 1 != words_a.size () || N1 - 1 != words_b.size ()) throw std::runtime_error (" Sizes do not match!" );
331412 auto buf = array.request ();
332413 double * ptr = (double *) buf.ptr ;
333414
@@ -350,8 +431,65 @@ int calc_sum_cost(py::array_t<double> array, const StringVector& words_a,
350431 if (alen >= 50 || blen >= 50 ) {
351432 throw std::runtime_error (" Word is too long! Increase buffer" );
352433 }
353- transition_cost = calc_edit_distance_fast (char_dist_buffer.data (), a.data (), b.data (), a.size (), b.size ())
354- / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
434+ if (use_fast_edit_distance) {
435+ transition_cost =
436+ calc_edit_distance_fast (char_dist_buffer.data (), a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
437+ } else {
438+ transition_cost =
439+ levdistance (a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
440+ }
441+ } else {
442+ transition_cost = words_a[i-1 ] == words_b[j-1 ] ? 0 . : 1 .;
443+ }
444+
445+ double upc = ptr[(i-1 ) * N1 + j] + 1 .;
446+ double leftc = ptr[i * N1 + j - 1 ] + 1 .;
447+ double diagc = ptr[(i-1 ) * N1 + j - 1 ] + transition_cost;
448+ double sum = std::min (upc, std::min (leftc, diagc));
449+ ptr[i * N1 + j] = sum;
450+ }
451+ }
452+ return ptr[M1*N1 - 1 ];
453+ }
454+
455+
456+
457+ int calc_sum_cost_lists (py::array_t <double > array, const std::vector<std::string>& words_a,
458+ const std::vector<std::string>& words_b, const bool use_chardist, const bool use_fast_edit_distance=true ) {
459+ if ( array.ndim () != 2 )
460+ throw std::runtime_error (" Input should be 2-D NumPy array" );
461+
462+ int M1 = array.shape ()[0 ], N1 = array.shape ()[1 ];
463+ if (M1 - 1 != words_a.size () || N1 - 1 != words_b.size ()) throw std::runtime_error (" Sizes do not match!" );
464+ auto buf = array.request ();
465+ double * ptr = (double *) buf.ptr ;
466+
467+ std::vector<int32> char_dist_buffer;
468+ if (use_chardist) {
469+ char_dist_buffer.resize (100 );
470+ }
471+
472+ ptr[0 ] = 0 ;
473+ for (int32 i = 1 ; i < M1; i++) ptr[i*N1] = ptr[(i-1 )*N1] + 1 ;
474+ for (int32 j = 1 ; j < N1; j++) ptr[j] = ptr[j-1 ] + 1 ;
475+ for (int32 i = 1 ; i < M1; i++) {
476+ for (int32 j = 1 ; j < N1; j++) {
477+ double transition_cost;
478+ if (use_chardist) {
479+ const std::string& a = words_a[i-1 ];
480+ const std::string& b = words_b[j-1 ];
481+ int alen = a.size ();
482+ int blen = b.size ();
483+ if (alen >= 50 || blen >= 50 ) {
484+ throw std::runtime_error (" Word is too long! Increase buffer" );
485+ }
486+ if (use_fast_edit_distance) {
487+ transition_cost =
488+ calc_edit_distance_fast (char_dist_buffer.data (), a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
489+ } else {
490+ transition_cost =
491+ levdistance (a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
492+ }
355493 } else {
356494 transition_cost = words_a[i-1 ] == words_b[j-1 ] ? 0 . : 1 .;
357495 }
@@ -374,7 +512,7 @@ int calc_sum_cost_ctm(py::array_t<double> array, std::vector<std::string>& texta
374512 throw std::runtime_error (" Input should be 2-D NumPy array" );
375513
376514 int M = array.shape ()[0 ], N = array.shape ()[1 ];
377- if (M != texta.size () || N != textb.size ()) throw std::runtime_error (" Sizes do not match!" );
515+ if (M != texta.size () || N != textb.size ()) throw std::runtime_error (" s do not match!" );
378516 auto buf = array.request ();
379517 double * ptr = (double *) buf.ptr ;
380518// std::cout << "STARTING"<<std::endl;
@@ -438,9 +576,11 @@ void init_stringvector(py::module_ &m);
438576PYBIND11_MODULE (texterrors_align,m) {
439577 m.doc () = " pybind11 plugin" ;
440578 m.def (" calc_sum_cost" , &calc_sum_cost, " Calculate summed cost matrix" );
579+ m.def (" calc_sum_cost_lists" , &calc_sum_cost_lists, " Calculate summed cost matrix" );
441580 m.def (" calc_sum_cost_ctm" , &calc_sum_cost_ctm, " Calculate summed cost matrix" );
442581 m.def (" get_best_path" , &get_best_path, " get_best_path" );
443582 m.def (" get_best_path_ctm" , &get_best_path_ctm, " get_best_path_ctm" );
583+ m.def (" get_best_path_lists" , &get_best_path_lists, " get_best_path_lists" );
444584 m.def (" lev_distance" , lev_distance<int >);
445585 m.def (" lev_distance" , lev_distance<char >);
446586 m.def (" lev_distance_str" , &lev_distance_str);
0 commit comments