Skip to content

Commit 646fb8f

Browse files
committed
State with options for using StringVector or fast editdistance
1 parent 6172fdd commit 646fb8f

File tree

5 files changed

+195
-31
lines changed

5 files changed

+195
-31
lines changed

libs/stringvector.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ StringVector::StringVector(const vector<std::string>& words) {
2727
current_index_ = 0;
2828
}
2929

30-
const int StringVector::Size() const {
30+
const int StringVector::size() const {
3131
return wordend_index_.size();
3232
}
3333

3434
const std::string_view StringVector::operator[](const int i) const {
35-
if (i < 0 || i >= Size()) {
35+
if (i < 0 || i >= size()) {
3636
throw std::runtime_error("Invalid index");
3737
}
3838
int start_index = 0;
@@ -49,15 +49,15 @@ StringVector StringVector::iter() {
4949
}
5050

5151
const std::string_view StringVector::next() {
52-
if (current_index_ == Size()) {
52+
if (current_index_ == size()) {
5353
throw pybind11::stop_iteration();
5454
}
5555
return (*this)[current_index_++];
5656
}
5757

5858
std::string StringVector::Str() const {
5959
std::string repr = "";
60-
for (int i = 0; i < Size(); i++) {
60+
for (int i = 0; i < size(); i++) {
6161
repr += std::string{(*this)[i]} + " ";
6262
}
6363
return repr;
@@ -69,8 +69,8 @@ StringVector::~StringVector() {}
6969
void init_stringvector(py::module &m) {
7070
py::class_<StringVector>(m, "StringVector")
7171
.def(py::init<const py::list&>())
72-
.def("size", &StringVector::Size)
73-
.def("__len__", &StringVector::Size)
72+
.def("size", &StringVector::size)
73+
.def("__len__", &StringVector::size)
7474
.def("__getitem__", &StringVector::operator[])
7575
.def("__iter__", &StringVector::iter)
7676
.def("__next__", &StringVector::next)

libs/stringvector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class StringVector {
1616
StringVector(const vector<std::string>& words);
1717
~StringVector();
1818

19-
const int Size() const;
19+
const int size() const;
2020
const std::string_view operator[](const int i) const;
2121
StringVector iter();
2222
const std::string_view next();

libs/texterrors_align.cc

Lines changed: 147 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ enum direction{diag, move_left, up};
168168

169169
std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
170170
const StringVector& words_a,
171-
const StringVector& words_b, const bool use_chardiff) {
171+
const StringVector& words_b, const bool use_chardiff, const bool use_fast_edit_distance=true) {
172172
auto buf = array.request();
173173
double* cost_mat = (double*) buf.ptr;
174174
int32_t numr = array.shape()[0], numc = array.shape()[1];
@@ -202,8 +202,89 @@ std::vector<std::tuple<int32, int32> > get_best_path(py::array_t<double> array,
202202
if (alen >= 50 || blen >= 50) {
203203
throw std::runtime_error("Word is too long! Increase buffer");
204204
}
205-
diag_trans_cost =
205+
if (use_fast_edit_distance) {
206+
diag_trans_cost =
206207
calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
208+
} else {
209+
diag_trans_cost =
210+
levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
211+
}
212+
} else {
213+
diag_trans_cost = a == b ? 0. : 1.;
214+
}
215+
216+
if (isclose(diagc + diag_trans_cost, current_cost)) {
217+
direc = diag;
218+
} else if (isclose(upc + up_trans_cost, current_cost)) {
219+
direc = up;
220+
} else if (isclose(leftc + left_trans_cost, current_cost)) {
221+
direc = move_left;
222+
} else {
223+
std::cout << a <<" "<<b<<" "<<i<<" "<<j<<" trans "<<diag_trans_cost<<" "<<left_trans_cost<<" "<<up_trans_cost<<" costs "<<current_cost<<" "<<diagc<<" "<<leftc<<" "<<upc <<std::endl;
224+
std::cout << (diag_trans_cost + diagc == current_cost) <<std::endl;
225+
std::cout << diag_trans_cost + diagc <<" "<<current_cost <<std::endl;
226+
throw std::runtime_error("Should not be possible !");
227+
}
228+
}
229+
230+
if (direc == up) {
231+
i--;
232+
bestpath.emplace_back(i, -1); // -1 means null token
233+
} else if (direc == move_left) {
234+
j--;
235+
bestpath.emplace_back(-1, j);
236+
} else if (direc == diag) {
237+
i--, j--;
238+
bestpath.emplace_back(i, j);
239+
}
240+
}
241+
return bestpath;
242+
}
243+
244+
245+
std::vector<std::tuple<int32, int32> > get_best_path_lists(py::array_t<double> array,
246+
const std::vector<std::string>& words_a,
247+
const std::vector<std::string>& words_b, const bool use_chardiff, const bool use_fast_edit_distance=true) {
248+
auto buf = array.request();
249+
double* cost_mat = (double*) buf.ptr;
250+
int32_t numr = array.shape()[0], numc = array.shape()[1];
251+
std::vector<int32> char_dist_buffer;
252+
if (use_chardiff) {
253+
char_dist_buffer.resize(100);
254+
}
255+
256+
std::vector<std::tuple<int, int> > bestpath;
257+
int i = numr - 1, j = numc - 1;
258+
while (i != 0 || j != 0) {
259+
double upc, leftc, diagc;
260+
direction direc;
261+
if (i == 0) {
262+
direc = move_left;
263+
} else if (j == 0) {
264+
direc = up;
265+
} else {
266+
float current_cost = cost_mat[i * numc + j];
267+
upc = cost_mat[(i-1) * numc + j];
268+
leftc = cost_mat[i * numc + j - 1];
269+
diagc = cost_mat[(i-1) * numc + j - 1];
270+
const std::string& a = words_a[i-1];
271+
const std::string& b = words_b[j-1];
272+
double up_trans_cost = 1.0;
273+
double left_trans_cost = 1.0;
274+
double diag_trans_cost;
275+
if (use_chardiff) {
276+
int alen = a.size();
277+
int blen = b.size();
278+
if (alen >= 50 || blen >= 50) {
279+
throw std::runtime_error("Word is too long! Increase buffer");
280+
}
281+
if (use_fast_edit_distance) {
282+
diag_trans_cost =
283+
calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
284+
} else {
285+
diag_trans_cost =
286+
levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
287+
}
207288
} else {
208289
diag_trans_cost = a == b ? 0. : 1.;
209290
}
@@ -322,12 +403,12 @@ void get_best_path_ctm(py::array_t<double> array, py::list& bestpath_lst, std::v
322403

323404

324405
int calc_sum_cost(py::array_t<double> array, const StringVector& words_a,
325-
const StringVector& words_b, const bool use_chardist) {
406+
const StringVector& words_b, const bool use_chardist, const bool use_fast_edit_distance=true) {
326407
if ( array.ndim() != 2 )
327408
throw std::runtime_error("Input should be 2-D NumPy array");
328409

329410
int M1 = array.shape()[0], N1 = array.shape()[1];
330-
if (M1 - 1 != words_a.Size() || N1 - 1 != words_b.Size()) throw std::runtime_error("Sizes do not match!");
411+
if (M1 - 1 != words_a.size() || N1 - 1 != words_b.size()) throw std::runtime_error("Sizes do not match!");
331412
auto buf = array.request();
332413
double* ptr = (double*) buf.ptr;
333414

@@ -350,8 +431,65 @@ int calc_sum_cost(py::array_t<double> array, const StringVector& words_a,
350431
if (alen >= 50 || blen >= 50) {
351432
throw std::runtime_error("Word is too long! Increase buffer");
352433
}
353-
transition_cost = calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size())
354-
/ static_cast<double>(std::max(a.size(), b.size())) * 1.5;
434+
if (use_fast_edit_distance) {
435+
transition_cost =
436+
calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
437+
} else {
438+
transition_cost =
439+
levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
440+
}
441+
} else {
442+
transition_cost = words_a[i-1] == words_b[j-1] ? 0. : 1.;
443+
}
444+
445+
double upc = ptr[(i-1) * N1 + j] + 1.;
446+
double leftc = ptr[i * N1 + j - 1] + 1.;
447+
double diagc = ptr[(i-1) * N1 + j - 1] + transition_cost;
448+
double sum = std::min(upc, std::min(leftc, diagc));
449+
ptr[i * N1 + j] = sum;
450+
}
451+
}
452+
return ptr[M1*N1 - 1];
453+
}
454+
455+
456+
457+
int calc_sum_cost_lists(py::array_t<double> array, const std::vector<std::string>& words_a,
458+
const std::vector<std::string>& words_b, const bool use_chardist, const bool use_fast_edit_distance=true) {
459+
if ( array.ndim() != 2 )
460+
throw std::runtime_error("Input should be 2-D NumPy array");
461+
462+
int M1 = array.shape()[0], N1 = array.shape()[1];
463+
if (M1 - 1 != words_a.size() || N1 - 1 != words_b.size()) throw std::runtime_error("Sizes do not match!");
464+
auto buf = array.request();
465+
double* ptr = (double*) buf.ptr;
466+
467+
std::vector<int32> char_dist_buffer;
468+
if (use_chardist) {
469+
char_dist_buffer.resize(100);
470+
}
471+
472+
ptr[0] = 0;
473+
for (int32 i = 1; i < M1; i++) ptr[i*N1] = ptr[(i-1)*N1] + 1;
474+
for (int32 j = 1; j < N1; j++) ptr[j] = ptr[j-1] + 1;
475+
for(int32 i = 1; i < M1; i++) {
476+
for(int32 j = 1; j < N1; j++) {
477+
double transition_cost;
478+
if (use_chardist) {
479+
const std::string& a = words_a[i-1];
480+
const std::string& b = words_b[j-1];
481+
int alen = a.size();
482+
int blen = b.size();
483+
if (alen >= 50 || blen >= 50) {
484+
throw std::runtime_error("Word is too long! Increase buffer");
485+
}
486+
if (use_fast_edit_distance) {
487+
transition_cost =
488+
calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
489+
} else {
490+
transition_cost =
491+
levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast<double>(std::max(a.size(), b.size())) * 1.5;
492+
}
355493
} else {
356494
transition_cost = words_a[i-1] == words_b[j-1] ? 0. : 1.;
357495
}
@@ -374,7 +512,7 @@ int calc_sum_cost_ctm(py::array_t<double> array, std::vector<std::string>& texta
374512
throw std::runtime_error("Input should be 2-D NumPy array");
375513

376514
int M = array.shape()[0], N = array.shape()[1];
377-
if (M != texta.size() || N != textb.size()) throw std::runtime_error("Sizes do not match!");
515+
if (M != texta.size() || N != textb.size()) throw std::runtime_error(" s do not match!");
378516
auto buf = array.request();
379517
double* ptr = (double*) buf.ptr;
380518
// std::cout << "STARTING"<<std::endl;
@@ -438,9 +576,11 @@ void init_stringvector(py::module_ &m);
438576
PYBIND11_MODULE(texterrors_align,m) {
439577
m.doc() = "pybind11 plugin";
440578
m.def("calc_sum_cost", &calc_sum_cost, "Calculate summed cost matrix");
579+
m.def("calc_sum_cost_lists", &calc_sum_cost_lists, "Calculate summed cost matrix");
441580
m.def("calc_sum_cost_ctm", &calc_sum_cost_ctm, "Calculate summed cost matrix");
442581
m.def("get_best_path", &get_best_path, "get_best_path");
443582
m.def("get_best_path_ctm", &get_best_path_ctm, "get_best_path_ctm");
583+
m.def("get_best_path_lists", &get_best_path_lists, "get_best_path_lists");
444584
m.def("lev_distance", lev_distance<int>);
445585
m.def("lev_distance", lev_distance<char>);
446586
m.def("lev_distance_str", &lev_distance_str);

tests/test_functions.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from texterrors.texterrors import StringVector
1212
from dataclasses import dataclass
1313
import difflib
14+
from loguru import logger
1415

1516
logger.remove()
1617
logger.add(sys.stderr, level="INFO")
@@ -347,19 +348,25 @@ def test_cli_basic():
347348

348349

349350
def test_speed():
350-
ref = create_inp(open('tests/reftext').read().splitlines())
351-
hyp = create_inp(open('tests/hyptext').read().splitlines())
352-
# import cProfile
353-
# pr = cProfile.Profile()
351+
import time
352+
import sys
353+
logger.remove()
354+
logger.add(sys.stdout, level='INFO')
355+
ref = create_inp(open('test-other.ark').read().splitlines())
356+
hyp = create_inp(open('test-other-mod.ark').read().splitlines())
357+
import cProfile
358+
pr = cProfile.Profile()
354359

355-
# pr.enable()
360+
pr.enable()
356361
buffer = io.StringIO()
357362
start_time = time.perf_counter()
358363
texterrors.process_output(ref, hyp, fh=buffer, ref_file='ref', hyp_file='hyp',
359364
skip_detailed=True, use_chardiff=True, debug=False)
360365
process_time = time.perf_counter() - start_time
361-
# pr.disable()
362-
# pr.dump_stats('speed.prof')
366+
367+
pr.disable()
368+
pr.dump_stats('speed.prof')
363369

364370
logger.info(f'Processing time for speed test is {process_time}')
365-
assert process_time < 2.
371+
assert process_time < 2.
372+

texterrors/texterrors.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from termcolor import colored
1616

1717
OOV_SYM = '<unk>'
18+
CPP_WORDS_CONTAINER = True
1819

1920

2021
def convert_to_int(lst_a, lst_b, dct):
@@ -50,7 +51,7 @@ def seq_distance(a, b):
5051
len_a = len(a)
5152
len_b = len(b)
5253
summed_cost = np.zeros((len_a + 1, len_b + 1), dtype=np.float64, order="C")
53-
cost = texterrors_align.calc_sum_cost(summed_cost, a, b, False)
54+
cost = texterrors_align.calc_sum_cost(summed_cost, a, b, False, True)
5455
return cost
5556

5657

@@ -61,14 +62,21 @@ def _align_texts(words_a, words_b, use_chardiff, debug, insert_tok):
6162
if debug:
6263
print(words_a)
6364
print(words_b)
64-
cost = texterrors_align.calc_sum_cost(summed_cost, words_a, words_b, use_chardiff)
65+
if CPP_WORDS_CONTAINER:
66+
cost = texterrors_align.calc_sum_cost(summed_cost, words_a, words_b, use_chardiff, True)
67+
else:
68+
cost = texterrors_align.calc_sum_cost_lists(summed_cost, words_a, words_b, use_chardiff, True)
6569

6670
if debug:
6771
np.set_printoptions(linewidth=300)
6872
np.savetxt('summedcost', summed_cost, fmt='%.3f', delimiter='\t')
6973

70-
best_path_reversed = texterrors_align.get_best_path(summed_cost,
71-
words_a, words_b, use_chardiff)
74+
if CPP_WORDS_CONTAINER:
75+
best_path_reversed = texterrors_align.get_best_path(summed_cost,
76+
words_a, words_b, use_chardiff, True)
77+
else:
78+
best_path_reversed = texterrors_align.get_best_path_lists(summed_cost,
79+
words_a, words_b, use_chardiff, True)
7280

7381
aligned_a, aligned_b = [], []
7482
for i, j in reversed(best_path_reversed):
@@ -217,11 +225,15 @@ def read_ref_file(ref_f, isark):
217225
if isark:
218226
utt, *words = line.split()
219227
assert utt not in ref_utts, 'There are repeated utterances in reference file! Exiting'
220-
ref_utts[utt] = Utt(utt, StringVector(words))
228+
if CPP_WORDS_CONTAINER:
229+
words = StringVector(words)
230+
ref_utts[utt] = Utt(utt, words)
221231
else:
222232
words = line.split()
223233
i = str(i)
224-
ref_utts[i] = Utt(i, StringVector(words))
234+
if CPP_WORDS_CONTAINER:
235+
words = StringVector(words)
236+
ref_utts[i] = Utt(i, words)
225237
return ref_utts
226238

227239

@@ -232,14 +244,19 @@ def read_hyp_file(hyp_f, isark, oracle_wer):
232244
if isark:
233245
utt, *words = line.split()
234246
words = [w for w in words if w != OOV_SYM]
247+
if CPP_WORDS_CONTAINER:
248+
words = StringVector(words)
235249
if not oracle_wer:
236-
hyp_utts[utt] = Utt(utt, StringVector(words))
250+
hyp_utts[utt] = Utt(utt, words)
237251
else:
238-
hyp_utts[utt].append(Utt(utt, StringVector(words)))
252+
hyp_utts[utt].append(Utt(utt, words))
239253
else:
240254
words = line.split()
241255
i = str(i)
242-
hyp_utts[i] = Utt(i, StringVector([w for w in words if w != OOV_SYM]))
256+
words = [w for w in words if w != OOV_SYM]
257+
if CPP_WORDS_CONTAINER:
258+
words = StringVector(words)
259+
hyp_utts[i] = Utt(i, words)
243260
return hyp_utts
244261

245262

0 commit comments

Comments
 (0)