Skip to content

Commit 6172fdd

Browse files
committed
is bug but faster
1 parent 502f2dd commit 6172fdd

File tree

6 files changed

+124
-20
lines changed

6 files changed

+124
-20
lines changed

libs/stringvector.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ const std::string_view StringVector::next() {
5555
return (*this)[current_index_++];
5656
}
5757

58+
std::string StringVector::Str() const {
59+
std::string repr = "";
60+
for (int i = 0; i < Size(); i++) {
61+
repr += std::string{(*this)[i]} + " ";
62+
}
63+
return repr;
64+
}
5865

5966
StringVector::~StringVector() {}
6067

@@ -66,5 +73,6 @@ void init_stringvector(py::module &m) {
6673
.def("__len__", &StringVector::Size)
6774
.def("__getitem__", &StringVector::operator[])
6875
.def("__iter__", &StringVector::iter)
69-
.def("__next__", &StringVector::next);
76+
.def("__next__", &StringVector::next)
77+
.def("__str__", &StringVector::Str);
7078
}

libs/stringvector.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class StringVector {
2020
const std::string_view operator[](const int i) const;
2121
StringVector iter();
2222
const std::string_view next();
23+
std::string Str() const;
2324

2425
std::string data_;
2526
std::vector<int> wordend_index_;

libs/texterrors_align.cc

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,56 @@ struct Pair {
2727
int16_t j;
2828
};
2929

30+
31+
// Computes the Levenshtein distance (unit cost for insertion, deletion and
// substitution) between the first M characters of `a` and the first N
// characters of `b`.
//
// cost_mat: caller-owned scratch space. Only entries [0, N] are touched and
//           every one of them is written before it is read, so the buffer
//           needs at least N + 1 elements and may be reused between calls
//           without clearing.
// a, b:     character sequences; need not be NUL-terminated.
// M, N:     lengths of `a` and `b`; assumed to be non-negative.
//
// Returns the edit distance.
int calc_edit_distance_fast(int32* cost_mat, const char* a, const char* b,
                            const int32 M, const int32 N) {
  // Row 0 of the DP table: the empty prefix of `a` becomes b[0..j) with j
  // insertions.
  for (int32 j = 0; j <= N; ++j) {
    cost_mat[j] = j;
  }

  // A single row is updated in place. While processing row i, cost_mat[j]
  // still holds the row i-1 value until it is overwritten, and `diagc`
  // carries the row i-1, column j-1 value that was just overwritten.
  for (int32 i = 1; i <= M; ++i) {
    int32 diagc = cost_mat[0];
    cost_mat[0] = i;  // column 0: i deletions turn a[0..i) into the empty string
    const char a_char = a[i - 1];

    for (int32 j = 1; j <= N; ++j) {
      const int32 upc = cost_mat[j];        // row i-1, column j
      const int32 leftc = cost_mat[j - 1];  // row i, column j-1 (already updated)
      const int32 transition_cost = a_char == b[j - 1] ? 0 : 1;

      cost_mat[j] = std::min(diagc + transition_cost, std::min(upc, leftc) + 1);
      diagc = upc;  // becomes the diagonal neighbour of column j+1
    }
  }

  return cost_mat[N];
}
78+
79+
3080
template <class T>
3181
void create_lev_cost_mat(int32* cost_mat, const T* a, const T* b,
3282
const int32 M, const int32 N) {
@@ -109,6 +159,11 @@ int lev_distance_str(std::string a, std::string b) {
109159
return levdistance(a.data(), b.data(), a.size(), b.size());
110160
}
111161

162+
int calc_edit_distance_fast_str(std::string a, std::string b) {
163+
std::vector<int> buffer(a.size() + b.size() + 2);
164+
return calc_edit_distance_fast(buffer.data(), a.data(), b.data(), a.size(), b.size());
165+
}
166+
112167
// Step directions: diagonal, left, up.
// NOTE(review): presumably the moves taken while walking the alignment cost matrix — confirm at the use sites.
enum direction{diag, move_left, up};
113168

114169
std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
@@ -117,6 +172,10 @@ std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
117172
auto buf = array.request();
118173
double* cost_mat = (double*) buf.ptr;
119174
int32_t numr = array.shape()[0], numc = array.shape()[1];
175+
std::vector<int32> char_dist_buffer;
176+
if (use_chardiff) {
177+
char_dist_buffer.resize(100);
178+
}
120179

121180
std::vector<std::tuple<int, int> > bestpath;
122181
int i = numr - 1, j = numc - 1;
@@ -138,8 +197,13 @@ std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
138197
double left_trans_cost = 1.0;
139198
double diag_trans_cost;
140199
if (use_chardiff) {
200+
int alen = a.size();
201+
int blen = b.size();
202+
if (alen >= 50 || blen >= 50) {
203+
throw std::runtime_error("Word is too long! Increase buffer");
204+
}
141205
diag_trans_cost =
142-
levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
206+
calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
143207
} else {
144208
diag_trans_cost = a == b ? 0. : 1.;
145209
}
@@ -266,6 +330,12 @@ int calc_sum_cost(py::array_t<double> array, const StringVector& words_a,
266330
if (M1 - 1 != words_a.Size() || N1 - 1 != words_b.Size()) throw std::runtime_error("Sizes do not match!");
267331
auto buf = array.request();
268332
double* ptr = (double*) buf.ptr;
333+
334+
std::vector<int32> char_dist_buffer;
335+
if (use_chardist) {
336+
char_dist_buffer.resize(100);
337+
}
338+
269339
ptr[0] = 0;
270340
for (int32 i = 1; i < M1; i++) ptr[i*N1] = ptr[(i-1)*N1] + 1;
271341
for (int32 j = 1; j < N1; j++) ptr[j] = ptr[j-1] + 1;
@@ -275,8 +345,13 @@ int calc_sum_cost(py::array_t<double> array, const StringVector& words_a,
275345
if (use_chardist) {
276346
const std::string_view a = words_a[i-1];
277347
const std::string_view b = words_b[j-1];
278-
transition_cost = levdistance(a.data(), b.data(),
279-
a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
348+
int alen = a.size();
349+
int blen = b.size();
350+
if (alen >= 50 || blen >= 50) {
351+
throw std::runtime_error("Word is too long! Increase buffer");
352+
}
353+
transition_cost = calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size())
354+
/ static_cast<double>(std::max(a.size(), b.size())) * 1.5;
280355
} else {
281356
transition_cost = words_a[i-1] == words_b[j-1] ? 0. : 1.;
282357
}
@@ -369,5 +444,6 @@ PYBIND11_MODULE(texterrors_align,m) {
369444
m.def("lev_distance", lev_distance<int>);
370445
m.def("lev_distance", lev_distance<char>);
371446
m.def("lev_distance_str", &lev_distance_str);
447+
m.def("calc_edit_distance_fast_str", &calc_edit_distance_fast_str);
372448
init_stringvector(m);
373449
}

tests/test_functions.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
import os
44
import io
55
import time
6-
import logging
6+
import sys
7+
from loguru import logger
78

89
import Levenshtein as levd
910
from texterrors import texterrors
1011
from texterrors.texterrors import StringVector
1112
from dataclasses import dataclass
1213
import difflib
1314

14-
15-
logger = logging.getLogger(__name__)
15+
logger.remove()
16+
logger.add(sys.stderr, level="INFO")
1617

1718
def show_diff(text1, text2):
1819
# Split the strings into lines to compare them line by line
@@ -38,6 +39,14 @@ def test_levd():
3839
assert d1 == d2, f'{a} {b} {d1} {d2}'
3940

4041

42+
# def test_calc_edit_distance_fast():
43+
# pairs = ['a', '', '', 'a', 'MOZILLA', 'MUSIAL', 'ARE', 'MOZILLA', 'TURNIPS', 'TENTH', 'POSTERS', 'POSTURE']
44+
# for a, b in zip(pairs[:-1:2], pairs[1::2]):
45+
# d1 = texterrors.calc_edit_distance_fast(a, b)
46+
# d2 = levd.distance(a, b)
47+
# assert d1 == d2, f'{a} {b} fasteditdist={d1} ref={d2}'
48+
49+
4150
def calc_wer(ref, b):
4251
cnt = 0
4352
err = 0
@@ -266,7 +275,7 @@ def test_process_output_colored():
266275
hyps = create_inp(hyplines)
267276

268277
buffer = io.StringIO()
269-
texterrors.process_output(refs, hyps, buffer, ref_file='A', hyp_file='B', nocolor=False)
278+
texterrors.process_output(refs, hyps, buffer, ref_file='A', hyp_file='B', nocolor=False, terminal_width=80)
270279
output = buffer.getvalue()
271280
ref = """\"A\" is treated as reference (white and green), \"B\" as hypothesis (white and red).
272281
Per utt details:
@@ -297,8 +306,8 @@ def test_process_output_colored():
297306
es>sie\t1\t2
298307
ja>auch\t1\t1
299308
"""
300-
#print(ref, file=open('ref', 'w'))
301-
#print(output, file=open('output', 'w'))
309+
print(ref, file=open('ref', 'w'))
310+
print(output, file=open('output', 'w'))
302311
assert ref == output
303312

304313

@@ -340,16 +349,17 @@ def test_cli_basic():
340349
def test_speed():
341350
ref = create_inp(open('tests/reftext').read().splitlines())
342351
hyp = create_inp(open('tests/hyptext').read().splitlines())
343-
import cProfile
344-
pr = cProfile.Profile()
352+
# import cProfile
353+
# pr = cProfile.Profile()
345354

346-
pr.enable()
355+
# pr.enable()
347356
buffer = io.StringIO()
348357
start_time = time.perf_counter()
349-
texterrors.process_output(ref, hyp, fh=buffer, ref_file='ref', hyp_file='hyp', skip_detailed=True)
358+
texterrors.process_output(ref, hyp, fh=buffer, ref_file='ref', hyp_file='hyp',
359+
skip_detailed=True, use_chardiff=True, debug=False)
350360
process_time = time.perf_counter() - start_time
351-
pr.disable()
352-
pr.dump_stats('speed.prof')
361+
# pr.disable()
362+
# pr.dump_stats('speed.prof')
353363

354364
logger.info(f'Processing time for speed test is {process_time}')
355365
assert process_time < 2.

texterrors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .texterrors import align_texts, process_lines, lev_distance, get_oov_cer, align_texts_ctm, seq_distance, \
2-
process_output, process_multiple_outputs
2+
process_output, process_multiple_outputs, calc_edit_distance_fast

texterrors/texterrors.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def lev_distance(a, b):
4040
return texterrors_align.lev_distance(a, b)
4141

4242

43+
def calc_edit_distance_fast(a, b):
    """ Edit distance between the strings a and b, computed by the compiled texterrors_align extension. """
    distance = texterrors_align.calc_edit_distance_fast_str(a, b)
    return distance
45+
46+
4347
def seq_distance(a, b):
4448
""" This function is for when a and b have strings as elements (variable length). """
4549
assert isinstance(a, StringVector) and isinstance(b, StringVector), 'Input types should be of type StringVector!'
@@ -53,6 +57,10 @@ def seq_distance(a, b):
5357
def _align_texts(words_a, words_b, use_chardiff, debug, insert_tok):
5458
summed_cost = np.zeros((len(words_a) + 1, len(words_b) + 1), dtype=np.float64,
5559
order="C")
60+
61+
if debug:
62+
print(words_a)
63+
print(words_b)
5664
cost = texterrors_align.calc_sum_cost(summed_cost, words_a, words_b, use_chardiff)
5765

5866
if debug:
@@ -656,10 +664,11 @@ def process_multiple_outputs(ref_utts, hypa_utts, hypb_utts, fh, num_top_errors,
656664
def process_output(ref_utts, hyp_utts, fh, ref_file, hyp_file, cer=False, num_top_errors=10, oov_set=None, debug=False,
657665
use_chardiff=True, isctm=False, skip_detailed=False,
658666
keywords=None, utt_group_map=None, oracle_wer=False,
659-
freq_sort=False, nocolor=False, insert_tok='<eps>'):
667+
freq_sort=False, nocolor=False, insert_tok='<eps>', terminal_width=None):
660668

661-
terminal_width, _ = shutil.get_terminal_size()
662-
terminal_width = 120 if terminal_width >= 120 else terminal_width
669+
if terminal_width is None:
670+
terminal_width, _ = shutil.get_terminal_size()
671+
terminal_width = 120 if terminal_width >= 120 else terminal_width
663672

664673
if oov_set is None:
665674
oov_set = set()

0 commit comments

Comments
 (0)