@@ -27,6 +27,56 @@ struct Pair {
2727 int16_t j;
2828};
2929
30+
31+ int calc_edit_distance_fast (int32* cost_mat, const char * a, const char * b,
32+ const int32 M, const int32 N) {
33+ int row_length = N+1 ;
34+ // std::cout << "STARTING M="<< M<< " N="<<N<<std::endl;
35+ for (int32 i = 0 ; i <= M; ++i) {
36+ for (int32 j = 0 ; j <= N; ++j) {
37+
38+ if (i == 0 && j == 0 ) {
39+ cost_mat[0 ] = 0 ;
40+ continue ;
41+ }
42+ if (i == 0 ) {
43+ cost_mat[j] = cost_mat[j - 1 ] + 1 ;
44+ continue ;
45+ }
46+ if (j == 0 ) {
47+ cost_mat[row_length] = cost_mat[0 ] + 1 ;
48+ continue ;
49+ }
50+ int32 transition_cost = a[i-1 ] == b[j-1 ] ? 0 : 1 ;
51+
52+ int32 upc = cost_mat[j] + 1 ;
53+ int32 leftc = cost_mat[row_length + j - 1 ] + 1 ;
54+ int32 diagc = cost_mat[j - 1 ] + transition_cost;
55+ int32 cost = std::min (upc, std::min (leftc, diagc) );
56+
57+ cost_mat[row_length + j] = cost;
58+ cost_mat[j - 1 ] = cost_mat[row_length + j - 1 ]; // copying result up after use
59+ }
60+ if (i > 0 ) {
61+ cost_mat[N] = cost_mat[row_length + N];
62+ }
63+
64+ // std::cout << "row "<<i;
65+ // for (int32 j = 0; j <= N; ++j) {
66+ // std::cout << " "<<cost_mat[j];
67+ // }
68+ // std::cout << std::endl;
69+ }
70+ // std::cout << "last row";
71+ // for (int32 j = 0; j <= N; ++j) {
72+ // std::cout <<" "<<cost_mat[row_length + j];
73+ // }
74+ // std::cout << std::endl;
75+
76+ return cost_mat[row_length - 1 ];
77+ }
78+
79+
3080template <class T >
3181void create_lev_cost_mat (int32* cost_mat, const T* a, const T* b,
3282 const int32 M, const int32 N) {
@@ -109,6 +159,11 @@ int lev_distance_str(std::string a, std::string b) {
109159 return levdistance (a.data (), b.data (), a.size (), b.size ());
110160}
111161
162+ int calc_edit_distance_fast_str (std::string a, std::string b) {
163+ std::vector<int > buffer (a.size () + b.size () + 2 );
164+ return calc_edit_distance_fast (buffer.data (), a.data (), b.data (), a.size (), b.size ());
165+ }
166+
// Step kinds: diagonal, left, or up. NOTE(review): presumably the moves taken
// while tracing a path through the alignment cost matrix - confirm at use sites.
112167enum direction{diag, move_left, up};
113168
114169std::vector<std::tuple<int32, int32> > get_best_path (py::array_t <double > array,
@@ -117,6 +172,10 @@ std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
117172 auto buf = array.request ();
118173 double * cost_mat = (double *) buf.ptr ;
119174 int32_t numr = array.shape ()[0 ], numc = array.shape ()[1 ];
175+ std::vector<int32> char_dist_buffer;
176+ if (use_chardiff) {
177+ char_dist_buffer.resize (100 );
178+ }
120179
121180 std::vector<std::tuple<int , int > > bestpath;
122181 int i = numr - 1 , j = numc - 1 ;
@@ -138,8 +197,13 @@ std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
138197 double left_trans_cost = 1.0 ;
139198 double diag_trans_cost;
140199 if (use_chardiff) {
200+ int alen = a.size ();
201+ int blen = b.size ();
202+ if (alen >= 50 || blen >= 50 ) {
203+ throw std::runtime_error (" Word is too long! Increase buffer" );
204+ }
141205 diag_trans_cost =
142- levdistance ( a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
206+ calc_edit_distance_fast (char_dist_buffer. data (), a.data (), b.data (), a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
143207 } else {
144208 diag_trans_cost = a == b ? 0 . : 1 .;
145209 }
@@ -266,6 +330,12 @@ int calc_sum_cost(py::array_t<double> array, const StringVector& words_a,
266330 if (M1 - 1 != words_a.Size () || N1 - 1 != words_b.Size ()) throw std::runtime_error (" Sizes do not match!" );
267331 auto buf = array.request ();
268332 double * ptr = (double *) buf.ptr ;
333+
334+ std::vector<int32> char_dist_buffer;
335+ if (use_chardist) {
336+ char_dist_buffer.resize (100 );
337+ }
338+
269339 ptr[0 ] = 0 ;
270340 for (int32 i = 1 ; i < M1; i++) ptr[i*N1] = ptr[(i-1 )*N1] + 1 ;
271341 for (int32 j = 1 ; j < N1; j++) ptr[j] = ptr[j-1 ] + 1 ;
@@ -275,8 +345,13 @@ int calc_sum_cost(py::array_t<double> array, const StringVector& words_a,
275345 if (use_chardist) {
276346 const std::string_view a = words_a[i-1 ];
277347 const std::string_view b = words_b[j-1 ];
278- transition_cost = levdistance (a.data (), b.data (),
279- a.size (), b.size ()) / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
348+ int alen = a.size ();
349+ int blen = b.size ();
350+ if (alen >= 50 || blen >= 50 ) {
351+ throw std::runtime_error (" Word is too long! Increase buffer" );
352+ }
353+ transition_cost = calc_edit_distance_fast (char_dist_buffer.data (), a.data (), b.data (), a.size (), b.size ())
354+ / static_cast <double >(std::max (a.size (), b.size ())) * 1.5 ;
280355 } else {
281356 transition_cost = words_a[i-1 ] == words_b[j-1 ] ? 0 . : 1 .;
282357 }
@@ -369,5 +444,6 @@ PYBIND11_MODULE(texterrors_align,m) {
369444 m.def (" lev_distance" , lev_distance<int >);
370445 m.def (" lev_distance" , lev_distance<char >);
371446 m.def (" lev_distance_str" , &lev_distance_str);
447+ m.def (" calc_edit_distance_fast_str" , &calc_edit_distance_fast_str);
372448 init_stringvector (m);
373449}
0 commit comments