Skip to content

Commit 79bc431

Browse files
authored
Merge pull request #798 from sillsdev/pull_results
Script for pulling experiment results into Excel sheets more efficiently, as long as the experiments are structured in a certain format.
2 parents f4ff3b4 + 2e01008 commit 79bc431

File tree

1 file changed

+354
-0
lines changed

1 file changed

+354
-0
lines changed

silnlp/nmt/exp_summary.py

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
import argparse
import glob
import os
import re
from typing import Dict, List, Optional, Set, Tuple

import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import Alignment, Font, PatternFill
from openpyxl.utils import get_column_letter

from .config import get_mt_exp_dir
13+
14+
15+
def read_group_results(
    file_path: str,
    target_book: str,
    all_books: List[str],
    metrics: List[str],
    key_word: str,
    chap_num: int,
) -> Tuple[Dict[str, Dict[int, Dict[int, List[float]]]], Set[int], int]:
    """Collect chapter-group results for every language pair under *file_path*.

    Expects the layout ``exp/<lang_pair>/<books>_<key_word>_order_<N>_ch/diff_predictions*``
    where ``<lang_pair>`` looks like ``src-trg`` and ``<books>`` is all_books joined by "+".

    Args:
        file_path: Root experiment directory to scan.
        target_book: Book whose chapters are extracted from each sheet.
        all_books: Trained books plus the target book; joined with "+" to form
            the expected group-folder prefix.
        metrics: Metric column names to pull from each diff_predictions file.
        key_word: Distinguishing token inside the group-folder name.
        chap_num: Running maximum chapter number; updated while reading.

    Returns:
        ``(lang_pair -> group order N -> chapter -> metric values,
        set of group orders found, updated chap_num)``.
    """
    # Compile both patterns once, outside the directory loops (they were
    # previously rebuilt on every iteration).
    lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")  # e.g. "en-fr"
    prefix = "+".join(all_books)
    group_pattern = re.compile(rf"^{re.escape(prefix)}_{key_word}_order_(\d+)_ch$")

    data = {}
    chapter_groups = set()
    for lang_pair in os.listdir(file_path):
        if not lang_pattern.match(lang_pair):
            continue

        data[lang_pair] = {}
        # os.listdir already yields bare entry names, so no basename() needed.
        for group_name in os.listdir(os.path.join(file_path, lang_pair)):
            m = group_pattern.match(group_name)
            if not m:
                continue
            order = int(m.group(1))
            folder_path = os.path.join(file_path, lang_pair, group_name)
            diff_pred_file = glob.glob(os.path.join(folder_path, "diff_predictions*"))
            if diff_pred_file:
                r, chap_num = extract_diff_pred_data(diff_pred_file[0], metrics, target_book, chap_num)
                data[lang_pair][order] = r
            else:
                data[lang_pair][order] = {}
                print(folder_path + " has no diff_predictions file.")
            chapter_groups.add(order)
            # The group order itself counts toward the chapter maximum.
            chap_num = max(chap_num, order)
    return data, chapter_groups, chap_num
47+
48+
49+
def extract_diff_pred_data(
    filename: str, metrics: List[str], target_book: str, chap_num: int, header_row: int = 5
) -> Tuple[Dict[int, List[float]], int]:
    """Extract per-chapter metric values for *target_book* from a diff_predictions sheet.

    Args:
        filename: Path to the diff_predictions Excel file.
        metrics: Metric column names to pull (matched case-insensitively).
        target_book: Book id (e.g. "MAT") whose chapters are extracted.
        chap_num: Running maximum chapter number; updated with chapters seen here.
        header_row: 0-based row index of the column headers inside the sheet.

    Returns:
        ``(chapter -> list of metric values, updated chap_num)``. A metric
        column missing from the sheet yields None in its slot. If the sheet
        cannot be parsed, returns ``({}, chap_num)`` after logging the error.

    NOTE(review): when several rows share a chapter, the last matching row
    wins — confirm diff_predictions contains one row per chapter.
    """
    metrics = [m.lower() for m in metrics]
    try:
        df = pd.read_excel(filename, header=header_row)
    except ValueError as e:
        # Name the failing file instead of printing a bare message.
        print(f"An error occurred in {filename}")
        print(e)
        return {}, chap_num

    # Normalize headers so metric lookup is case/whitespace insensitive.
    df.columns = [col.strip().lower() for col in df.columns]

    result = {}
    metric_warning = False
    uncalculated_metric = set()
    for _, row in df.iterrows():
        vref = row["vref"]
        # e.g. "MAT 12:3" -> book "MAT", chapter "12".
        m = re.match(r"(\d?[A-Z]{2,3}) (\d+)", str(vref))
        if not m:
            print(f"Invalid VREF format: {str(vref)}")
            continue

        book_name, chap = m.groups()
        if book_name != target_book:
            continue

        chap_num = max(chap_num, int(chap))
        values = []
        for metric in metrics:
            if metric in row:
                values.append(float(row[metric]))
            else:
                # Remember the missing metric so we warn exactly once per file.
                metric_warning = True
                uncalculated_metric.add(metric)
                values.append(None)

        result[int(chap)] = values

    if metric_warning:
        print(f"Warning: {uncalculated_metric} was not calculated in {filename}")

    return result, chap_num
92+
93+
94+
def flatten_dict(
    data: Dict,
    chapter_groups: List[int],
    metrics: List[str],
    chap_num: int,
    baseline: Optional[Dict] = None,
) -> List[List]:
    """Flatten nested experiment results into spreadsheet rows.

    Args:
        data: Mapping lang_pair -> chapter-group order -> chapter -> metric values.
        chapter_groups: Sorted group orders; position decides the column offset.
        metrics: Metric names; each takes (value, diff prev, diff start) columns
            per chapter group.
        chap_num: Highest chapter number seen; every chapter 1..chap_num gets a row.
        baseline: Optional mapping lang_pair -> chapter -> metric values.

    Returns:
        One row per (lang_pair, chapter): ``[lang_pair, chapter, rank placeholder,
        baseline metrics..., then per group a rank placeholder plus value/diff
        slots]``. Placeholders stay None; create_xlsx fills them with formulas.
    """
    # Avoid the shared-mutable-default pitfall; treat "no baseline" as empty.
    if baseline is None:
        baseline = {}

    rows = []
    if len(data) > 0:
        for lang_pair in data:
            for chap in range(1, chap_num + 1):
                row = [lang_pair, chap]
                # Pre-size with None placeholders. Total length equals
                # 2 + (1 rank + len(metrics)) baseline columns +
                # (1 rank + 3 * len(metrics)) columns per chapter group.
                row.extend([None, None, None] * len(metrics) * len(data[lang_pair]))
                row.extend([None] * len(chapter_groups))
                row.extend([None] * (1 + len(metrics)))

                for res_chap in data[lang_pair]:
                    if chap in data[lang_pair][res_chap]:
                        for m in range(len(metrics)):
                            # 0-based index of metric m's value cell inside the
                            # column block of chapter group res_chap.
                            index_m = (
                                3 + 1 + len(metrics) + chapter_groups.index(res_chap) * (len(metrics) * 3 + 1) + m * 3
                            )
                            row[index_m] = data[lang_pair][res_chap][chap][m]
                if len(baseline) > 0:
                    # Baseline may lack this lang pair or chapter (chap_num is a
                    # global maximum) — leave None instead of raising KeyError.
                    base_values = baseline.get(lang_pair, {}).get(chap)
                    if base_values is not None:
                        for m in range(len(metrics)):
                            row[3 + m] = base_values[m]
                rows.append(row)
    else:
        # No progression data: emit baseline-only rows.
        for lang_pair in baseline:
            for chap in range(1, chap_num + 1):
                row = [lang_pair, chap]
                row.extend([None] * (1 + len(metrics)))

                base_values = baseline[lang_pair].get(chap)
                if base_values is not None:
                    for m in range(len(metrics)):
                        row[3 + m] = base_values[m]
                rows.append(row)

    return rows
126+
127+
128+
def create_xlsx(
    rows: List[List], chapter_groups: List[int], output_path: str, metrics: List[str], chap_num: int
) -> None:
    """Write the flattened result rows into a formatted Excel workbook.

    Layout: row 1 holds merged group headers ("language pair", "Chapter",
    "Baseline", then one header per chapter group); row 2 holds per-column
    sub-headers (rank / metric / diff columns); data rows start at row 3.
    Rank and diff cells are filled with Excel formulas, not precomputed values.

    NOTE(review): the RANK.EQ window arithmetic assumes each language pair
    contributes exactly chap_num consecutive data rows — confirm against
    flatten_dict's output.
    """
    wb = Workbook()
    ws = wb.active

    # Each chapter group spans one rank column plus
    # (value, diff prev, diff start) per metric.
    num_col = len(metrics) * 3 + 1
    groups = [("language pair", 1), ("Chapter", 1), ("Baseline", (1 + len(metrics)))]
    for chap in chapter_groups:
        groups.append((chap, num_col))

    # Row 1: one merged header cell per group.
    col = 1
    for header, span in groups:
        start_col = get_column_letter(col)
        end_col = get_column_letter(col + span - 1)
        ws.merge_cells(f"{start_col}1:{end_col}1")
        ws.cell(row=1, column=col, value=header)
        col += span

    # Row 2 sub-headers: baseline gets rank + metric names; each chapter
    # group additionally gets the two diff columns per metric.
    sub_headers = []
    baseline_headers = []

    for i, metric in enumerate(metrics):
        if i == 0:
            baseline_headers.append("rank")
            sub_headers.append("rank")
        baseline_headers.append(metric)
        sub_headers.append(metric)
        sub_headers.append("diff (prev)")
        sub_headers.append("diff (start)")

    for i, baseline_header in enumerate(baseline_headers):
        ws.cell(row=2, column=3 + i, value=baseline_header)

    col = 3 + len(metrics) + 1
    for _ in range(len(groups) - 3):  # one block of sub-headers per chapter group
        for i, sub_header in enumerate(sub_headers):
            ws.cell(row=2, column=col + i, value=sub_header)

        col += len(sub_headers)

    for row in rows:
        ws.append(row)

    # Bold and center both header rows.
    for row_idx in [1, 2]:
        for col in range(1, ws.max_column + 1):
            ws.cell(row=row_idx, column=col).font = Font(bold=True)
            ws.cell(row=row_idx, column=col).alignment = Alignment(horizontal="center", vertical="center")

    # "language pair" and "Chapter" headers span both header rows.
    ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
    ws.merge_cells(start_row=1, start_column=2, end_row=2, end_column=2)
    ws.cell(row=1, column=1).alignment = Alignment(wrap_text=True, horizontal="center", vertical="center")

    cur_lang_pair = 3  # first data row of the current language-pair run
    for row_idx in range(3, ws.max_row + 1):
        # Baseline rank (col 3): rank of D within this language pair's
        # chap_num-row window, only when a baseline value exists.
        if ws.cell(row=row_idx, column=4).value is not None:
            ws.cell(row=row_idx, column=3).value = (
                f"=RANK.EQ(D{row_idx}, INDEX(D:D, INT((ROW(D{row_idx})-3)/{chap_num})*{chap_num}+3):INDEX(D:D, \
INT((ROW(D{row_idx})-3)/{chap_num})*{chap_num}+{chap_num}+2), 0)"
            )

        start_col = 3 + len(metrics) + 1
        end_col = ws.max_column

        while start_col < end_col:
            start_col += 1  # advance onto this group's first metric column
            if ws.cell(row=row_idx, column=start_col).value is None:
                # No data for this chapter group: gray out the rest of the row.
                for col in range(start_col - 1, ws.max_column + 1):
                    ws.cell(row=row_idx, column=col).fill = PatternFill(
                        fill_type="solid", start_color="CCCCCC", end_color="CCCCCC"
                    )
                break

            # Group rank formula goes in the column just before the metric value.
            col_letter = get_column_letter(start_col)
            ws.cell(row=row_idx, column=start_col - 1).value = (
                f"=RANK.EQ({col_letter}{row_idx}, INDEX({col_letter}:{col_letter}, \
INT((ROW({col_letter}{row_idx})-3)/{chap_num})*{chap_num}+3):INDEX({col_letter}:{col_letter}, \
INT((ROW({col_letter}{row_idx})-3)/{chap_num})*{chap_num}+{chap_num}+2), 0)"
            )

            for i in range(1, len(metrics) + 1):
                # Baseline column for metric i.
                start_letter = get_column_letter(3 + i)

                diff_prev_col = start_col + 1
                diff_start_col = start_col + 2

                # The first chapter group diffs against the baseline; later
                # groups diff against the same metric in the previous group.
                prev_letter = (
                    start_letter
                    if diff_prev_col <= 3 + len(metrics) + 1 + 3 * len(metrics)
                    else get_column_letter(diff_prev_col - 1 - 1 - 3 * len(metrics))
                )
                cur_letter = get_column_letter(diff_prev_col - 1)

                ws.cell(row=row_idx, column=diff_prev_col).value = f"={cur_letter}{row_idx}-{prev_letter}{row_idx}"
                ws.cell(row=row_idx, column=diff_start_col).value = f"={cur_letter}{row_idx}-{start_letter}{row_idx}"

                start_col += 3

        # Merge the language-pair cell over each run of identical values.
        if ws.cell(row=row_idx, column=1).value != ws.cell(row=cur_lang_pair, column=1).value:
            ws.merge_cells(start_row=cur_lang_pair, start_column=1, end_row=row_idx - 1, end_column=1)
            cur_lang_pair = row_idx
        elif row_idx == ws.max_row:
            ws.merge_cells(start_row=cur_lang_pair, start_column=1, end_row=row_idx, end_column=1)

    wb.save(output_path)
233+
234+
235+
# Sample command:
# python -m silnlp.nmt.exp_summary Catapult_Reloaded_Confidences
# --trained-books MRK --target-book MAT --metrics chrf3 confidence --key-word conf --baseline Catapult_Reloaded/2nd_book/MRK
# NOTE(review): the parser defines no positional argument, so the experiment
# name above presumably needs the --exp flag — confirm the sample.
def main() -> None:
    """Pull experiment results into one Excel summary workbook.

    Reads progression results (--exp) and/or a single baseline experiment
    (--baseline), flattens them with flatten_dict, and writes the workbook
    via create_xlsx to <result_dir>/a_result_folder/<books>.xlsx.
    """
    parser = argparse.ArgumentParser(
        description="Pulling results from a single experiment and/or multiple experiment groups. "
        "A valid experiment should have the following format: "
        "baseline/lang_pair/exp_group/diff_predictions or baseline/lang_pair/diff_predictions for a single experiment "
        "or "
        "exp/lang_pair/exp_groups/diff_predictions for multiple experiment groups "
        "More information in --exp and --baseline. "
        "Use --exp for multiple experiment groups and --baseline for a single experiment. "
        "At least one --exp or --baseline needs to be specified. "
    )
    parser.add_argument(
        "--exp",
        type=str,
        help="Experiment folder with progression results. "
        "A valid experiment groups should have the following format: "
        "exp/lang_pair/exp_groups/diff_predictions "
        "where there should be at least one exp_groups that naming in the following format: "
        "*book*+*book*_*key-word*_order_*number*_ch "
        "where *book*+*book*... are the combination of all --trained-books with the last one being --target-book. "
        "More information in --key-word. ",
    )
    parser.add_argument(
        "--trained-books", nargs="*", required=True, type=str.upper, help="Books that are trained in the exp "
    )
    parser.add_argument("--target-book", required=True, type=str.upper, help="Book that is going to be analyzed ")
    parser.add_argument(
        "--metrics",
        nargs="*",
        metavar="metrics",
        default=["chrf3", "confidence"],
        type=str.lower,
        help="Metrics that will be analyzed with ",
    )
    parser.add_argument(
        "--key-word",
        type=str,
        default="conf",
        help="Key word in the filename for the exp group to distinguish between the experiment purpose. "
        "For example, in LUK+ACT_conf_order_12_ch, the key-word should be conf. "
        "Another example, in LUK+ACT_standard_order_12_ch, the key-word should be standard. ",
    )
    parser.add_argument(
        "--baseline",
        type=str,
        help="A non-progression folder for a single experiment. "
        "A valid single experiment should have the following format: "
        "baseline/lang_pair/exp_group/diff_predictions where exp_group will be in the following format: "
        "*book*+*book*... as the combination of all --trained-books. "
        "or "
        "baseline/lang_pair/diff_predictions "
        "where the information of --trained-books should have already been indicated in the baseline name. ",
    )
    args = parser.parse_args()

    if not (args.exp or args.baseline):
        parser.error("At least one --exp or --baseline needs to be specified. ")

    trained_books = args.trained_books
    target_book = args.target_book
    all_books = trained_books + [target_book]
    metrics = args.metrics
    key_word = args.key_word

    # Highest chapter number seen anywhere; grows as results are read.
    chap_num = 0

    multi_group_exp_name = args.exp
    multi_group_exp_dir = get_mt_exp_dir(multi_group_exp_name) if multi_group_exp_name else None

    single_group_exp_name = args.baseline
    single_group_exp_dir = get_mt_exp_dir(single_group_exp_name) if single_group_exp_name else None

    # The summary workbook lands under the --exp dir when given, else --baseline.
    result_file_name = "+".join(all_books)
    result_dir = multi_group_exp_dir if multi_group_exp_dir else single_group_exp_dir
    os.makedirs(os.path.join(result_dir, "a_result_folder"), exist_ok=True)
    output_path = os.path.join(result_dir, "a_result_folder", f"{result_file_name}.xlsx")

    data = {}
    chapter_groups = set()
    if multi_group_exp_dir:
        data, chapter_groups, chap_num = read_group_results(
            multi_group_exp_dir, target_book, all_books, metrics, key_word, chap_num
        )
    # sorted() also normalizes the empty set to a list when --exp is absent.
    chapter_groups = sorted(chapter_groups)

    baseline_data = {}
    if single_group_exp_dir:
        # Compile once instead of per directory entry (was rebuilt in the loop).
        lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
        for lang_pair in os.listdir(single_group_exp_dir):
            if not lang_pattern.match(lang_pair):
                continue

            baseline_path = os.path.join(single_group_exp_dir, lang_pair)
            baseline_diff_pred = glob.glob(os.path.join(baseline_path, "diff_predictions*"))
            if baseline_diff_pred:
                baseline_data[lang_pair], chap_num = extract_diff_pred_data(
                    baseline_diff_pred[0], metrics, target_book, chap_num
                )
            else:
                # Fall back to the baseline/lang_pair/<books>/diff_predictions layout.
                print(f"Checking experiments under {baseline_path}...")
                sub_baseline_path = os.path.join(baseline_path, "+".join(trained_books))
                baseline_diff_pred = glob.glob(os.path.join(sub_baseline_path, "diff_predictions*"))
                if baseline_diff_pred:
                    baseline_data[lang_pair], chap_num = extract_diff_pred_data(
                        baseline_diff_pred[0], metrics, target_book, chap_num
                    )
                else:
                    print(f"Baseline experiment has no diff_predictions file in {sub_baseline_path}")

    print("Writing data...")
    rows = flatten_dict(data, chapter_groups, metrics, chap_num, baseline=baseline_data)
    create_xlsx(rows, chapter_groups, output_path, metrics, chap_num)
    print(f"Result is in {output_path}")
351+
352+
353+
# Script entry point when run via `python -m silnlp.nmt.exp_summary`.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)