|
| 1 | +import argparse |
| 2 | +import glob |
| 3 | +import os |
| 4 | +import re |
| 5 | +from typing import Dict, List, Set, Tuple |
| 6 | + |
| 7 | +import pandas as pd |
| 8 | +from openpyxl import Workbook |
| 9 | +from openpyxl.styles import Alignment, Font, PatternFill |
| 10 | +from openpyxl.utils import get_column_letter |
| 11 | + |
| 12 | +from .config import get_mt_exp_dir |
| 13 | + |
| 14 | + |
def read_group_results(
    file_path: str,
    target_book: str,
    all_books: List[str],
    metrics: List[str],
    key_word: str,
    chap_num: int,
) -> Tuple[Dict[str, Dict[int, Dict[int, List[float]]]], Set[int], int]:
    """Collect per-chapter metric results from every experiment group under file_path.

    Scans file_path for language-pair folders (names like "src-trg"), and inside
    each for experiment-group folders named "<book>+...+<book>_<key_word>_order_<N>_ch".
    Each group's diff_predictions spreadsheet is parsed via extract_diff_pred_data.

    Args:
        file_path: Root experiment directory containing lang-pair subfolders.
        target_book: Book whose chapters are extracted from the spreadsheets.
        all_books: Trained books plus the target book; joined with "+" to form
            the expected group-folder prefix.
        metrics: Metric column names to pull from each spreadsheet.
        key_word: Distinguishing keyword in the group-folder name.
        chap_num: Running maximum chapter/group number seen so far.

    Returns:
        A tuple of (results keyed lang_pair -> group order -> chapter -> values,
        the set of group order numbers found, the updated chap_num).
    """
    data: Dict[str, Dict[int, Dict[int, List[float]]]] = {}
    chapter_groups: Set[int] = set()

    # Both patterns and the book prefix are loop-invariant; build them once.
    lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
    prefix = "+".join(all_books)
    group_pattern = re.compile(rf"^{re.escape(prefix)}_{key_word}_order_(\d+)_ch$")

    for lang_pair in os.listdir(file_path):
        if not lang_pattern.match(lang_pair):
            continue

        data[lang_pair] = {}
        for groups in os.listdir(os.path.join(file_path, lang_pair)):
            if m := group_pattern.match(os.path.basename(groups)):
                order = int(m.group(1))
                folder_path = os.path.join(file_path, lang_pair, os.path.basename(groups))
                diff_pred_file = glob.glob(os.path.join(folder_path, "diff_predictions*"))
                if diff_pred_file:
                    result, chap_num = extract_diff_pred_data(diff_pred_file[0], metrics, target_book, chap_num)
                    data[lang_pair][order] = result
                else:
                    # Keep an empty entry so the group still shows up in the summary.
                    data[lang_pair][order] = {}
                    print(folder_path + " has no diff_predictions file.")
                chapter_groups.add(order)
                # The group's order number also bounds how many chapters to emit.
                chap_num = max(chap_num, order)
    return data, chapter_groups, chap_num
| 47 | + |
| 48 | + |
def extract_diff_pred_data(
    filename: str, metrics: List[str], target_book: str, chap_num: int, header_row=5
) -> Tuple[Dict[int, List[float]], int]:
    """Extract per-chapter metric values for target_book from a diff_predictions sheet.

    Args:
        filename: Path to the diff_predictions Excel file.
        metrics: Metric column names to read (matched case-insensitively).
        target_book: Book code whose rows are kept (e.g. "MAT").
        chap_num: Running maximum chapter number; updated from the rows seen.
        header_row: Zero-based row index of the header in the spreadsheet.

    Returns:
        A tuple of ({chapter: [value per metric, None where the metric column
        is missing]}, the updated chap_num). Returns ({}, chap_num) when the
        spreadsheet cannot be parsed.
    """
    # Column names are lower-cased below, so match metrics case-insensitively too.
    metrics = [m.lower() for m in metrics]
    try:
        df = pd.read_excel(filename, header=header_row)
    except ValueError as e:
        # e.g. header_row beyond the sheet, or a malformed workbook.
        # Name the offending file so the failure is actionable.
        print(f"An error occurs in {filename}")
        print(e)
        return {}, chap_num

    df.columns = [col.strip().lower() for col in df.columns]

    result = {}
    metric_warning = False
    uncalculated_metric = set()
    for _, row in df.iterrows():
        vref = row["vref"]
        # VREFs start with a book code then a chapter number, e.g. "MAT 5" or "1CO 13".
        m = re.match(r"(\d?[A-Z]{2,3}) (\d+)", str(vref))
        if not m:
            print(f"Invalid VREF format: {str(vref)}")
            continue

        book_name, chap = m.groups()
        if book_name != target_book:
            continue

        chap_num = max(chap_num, int(chap))
        values = []
        for metric in metrics:
            if metric in row:
                values.append(float(row[metric]))
            else:
                metric_warning = True
                uncalculated_metric.add(metric)
                values.append(None)

        # Later rows of the same chapter overwrite earlier ones, so the last
        # row per chapter wins (presumably the chapter-level summary row).
        result[int(chap)] = values

    if metric_warning:
        print(f"Warning: {uncalculated_metric} was not calculated in {filename}")

    return result, chap_num
| 92 | + |
| 93 | + |
def flatten_dict(
    data: Dict, chapter_groups: List[int], metrics: List[str], chap_num: int, baseline: Dict = None
) -> List[List]:
    """Flatten nested experiment results into worksheet rows.

    Each row is [lang_pair, chapter, <rank cell>, <baseline metric cells>,
    then per experiment group a (rank, metric, diff-prev, diff-start) slot per
    metric]. Cells left as None are either filled with formulas later by
    create_xlsx or represent missing data.

    Args:
        data: Results keyed lang_pair -> group order -> chapter -> values,
            as produced by read_group_results.
        chapter_groups: Sorted group order numbers (fixes column ordering).
        metrics: Metric names; each contributes three cells per group.
        chap_num: Rows are emitted for chapters 1..chap_num.
        baseline: Optional baseline results keyed lang_pair -> chapter -> values.
            Defaults to empty (None is normalized; a mutable literal default
            would be shared across calls).

    Returns:
        The list of row lists, ready to append to the worksheet.
    """
    # Normalize the default here instead of using a mutable default argument.
    baseline = {} if baseline is None else baseline

    rows = []
    if len(data) > 0:
        for lang_pair in data:
            for chap in range(1, chap_num + 1):
                row = [lang_pair, chap]
                row.extend([None, None, None] * len(metrics) * len(data[lang_pair]))
                row.extend([None] * len(chapter_groups))
                row.extend([None] * (1 + len(metrics)))

                for res_chap in data[lang_pair]:
                    if chap in data[lang_pair][res_chap]:
                        for m in range(len(metrics)):
                            # Index: 2 fixed cols + rank + baseline metrics,
                            # plus one (rank + 3*metrics)-wide slot per earlier group,
                            # plus 3 cells per earlier metric within this group.
                            index_m = (
                                3 + 1 + len(metrics) + chapter_groups.index(res_chap) * (len(metrics) * 3 + 1) + m * 3
                            )
                            row[index_m] = data[lang_pair][res_chap][chap][m]
                if len(baseline) > 0:
                    for m in range(len(metrics)):
                        row[3 + m] = baseline[lang_pair][chap][m] if lang_pair in baseline else None
                rows.append(row)
    else:
        # No progression data: emit baseline-only rows.
        for lang_pair in baseline:
            for chap in range(1, chap_num + 1):
                row = [lang_pair, chap]
                row.extend([None] * (1 + len(metrics)))

                for m in range(len(metrics)):
                    row[3 + m] = baseline[lang_pair][chap][m]
                rows.append(row)

    return rows
| 126 | + |
| 127 | + |
def create_xlsx(
    rows: List[List], chapter_groups: List[int], output_path: str, metrics: List[str], chap_num: int
) -> None:
    """Write the flattened rows to output_path as a formatted Excel summary.

    Builds a two-row merged header (language pair / Chapter / Baseline / one
    merged block per experiment group), appends the data rows, then fills in
    Excel RANK.EQ and difference formulas and greys out trailing empty cells.

    Args:
        rows: Row lists from flatten_dict; row layout must match the header
            layout computed here (same metrics and chapter_groups).
        chapter_groups: Sorted group order numbers, one header block each.
        metrics: Metric names; each group block is rank + 3 cells per metric.
        chap_num: Chapters per language pair; used in the RANK.EQ ranges so a
            rank is computed within each lang pair's chapter run.
    """
    wb = Workbook()
    ws = wb.active

    # Width of one experiment-group block: a rank column plus
    # (metric, diff-prev, diff-start) per metric.
    num_col = len(metrics) * 3 + 1
    groups = [("language pair", 1), ("Chapter", 1), ("Baseline", (1 + len(metrics)))]
    for chap in chapter_groups:
        groups.append((chap, num_col))

    # Row 1: merged top-level headers spanning each block.
    col = 1
    for header, span in groups:
        start_col = get_column_letter(col)
        end_col = get_column_letter(col + span - 1)
        ws.merge_cells(f"{start_col}1:{end_col}1")
        ws.cell(row=1, column=col, value=header)
        col += span

    # Row 2: sub-headers. The baseline block gets rank + metric names; each
    # group block additionally gets diff columns per metric.
    sub_headers = []
    baseline_headers = []

    for i, metric in enumerate(metrics):
        if i == 0:
            baseline_headers.append("rank")
            sub_headers.append("rank")
        baseline_headers.append(metric)
        sub_headers.append(metric)
        sub_headers.append("diff (prev)")
        sub_headers.append("diff (start)")

    for i, baseline_header in enumerate(baseline_headers):
        ws.cell(row=2, column=3 + i, value=baseline_header)

    col = 3 + len(metrics) + 1
    for _ in range(len(groups) - 3):  # one pass per experiment-group block
        for i, sub_header in enumerate(sub_headers):
            ws.cell(row=2, column=col + i, value=sub_header)

        col += len(sub_headers)

    for row in rows:
        ws.append(row)

    # Bold + center both header rows.
    for row_idx in [1, 2]:
        for col in range(1, ws.max_column + 1):
            ws.cell(row=row_idx, column=col).font = Font(bold=True)
            ws.cell(row=row_idx, column=col).alignment = Alignment(horizontal="center", vertical="center")

    # "language pair" and "Chapter" headers span both header rows.
    ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
    ws.merge_cells(start_row=1, start_column=2, end_row=2, end_column=2)
    ws.cell(row=1, column=1).alignment = Alignment(wrap_text=True, horizontal="center", vertical="center")

    cur_lang_pair = 3  # first data row; start of the current lang-pair run
    for row_idx in range(3, ws.max_row + 1):
        # Baseline rank (column C) over column D within this lang pair's
        # chap_num-row window. NOTE: the backslash continuation keeps the
        # following indentation inside the formula string (harmless in Excel).
        if ws.cell(row=row_idx, column=4).value is not None:
            ws.cell(row=row_idx, column=3).value = (
                f"=RANK.EQ(D{row_idx}, INDEX(D:D, INT((ROW(D{row_idx})-3)/{chap_num})*{chap_num}+3):INDEX(D:D, \
                INT((ROW(D{row_idx})-3)/{chap_num})*{chap_num}+{chap_num}+2), 0)"
            )

        start_col = 3 + len(metrics) + 1
        end_col = ws.max_column

        # Walk the experiment-group blocks left to right; stop (and grey out
        # the rest of the row) at the first block with no data.
        while start_col < end_col:
            start_col += 1
            if ws.cell(row=row_idx, column=start_col).value is None:
                for col in range(start_col - 1, ws.max_column + 1):
                    ws.cell(row=row_idx, column=col).fill = PatternFill(
                        fill_type="solid", start_color="CCCCCC", end_color="CCCCCC"
                    )
                break

            # Rank formula for this group's first metric column.
            col_letter = get_column_letter(start_col)
            ws.cell(row=row_idx, column=start_col - 1).value = (
                f"=RANK.EQ({col_letter}{row_idx}, INDEX({col_letter}:{col_letter}, \
                INT((ROW({col_letter}{row_idx})-3)/{chap_num})*{chap_num}+3):INDEX({col_letter}:{col_letter}, \
                INT((ROW({col_letter}{row_idx})-3)/{chap_num})*{chap_num}+{chap_num}+2), 0)"
            )

            # Per metric: diff vs. the previous group's value (or the baseline
            # for the first group) and diff vs. the baseline ("start").
            for i in range(1, len(metrics) + 1):
                start_letter = get_column_letter(3 + i)

                diff_prev_col = start_col + 1
                diff_start_col = start_col + 2

                prev_letter = (
                    start_letter
                    if diff_prev_col <= 3 + len(metrics) + 1 + 3 * len(metrics)
                    else get_column_letter(diff_prev_col - 1 - 1 - 3 * len(metrics))
                )
                cur_letter = get_column_letter(diff_prev_col - 1)

                ws.cell(row=row_idx, column=diff_prev_col).value = f"={cur_letter}{row_idx}-{prev_letter}{row_idx}"
                ws.cell(row=row_idx, column=diff_start_col).value = f"={cur_letter}{row_idx}-{start_letter}{row_idx}"

                start_col += 3

        # Merge the lang-pair cell (column A) over each contiguous run of rows.
        if ws.cell(row=row_idx, column=1).value != ws.cell(row=cur_lang_pair, column=1).value:
            ws.merge_cells(start_row=cur_lang_pair, start_column=1, end_row=row_idx - 1, end_column=1)
            cur_lang_pair = row_idx
        elif row_idx == ws.max_row:
            ws.merge_cells(start_row=cur_lang_pair, start_column=1, end_row=row_idx, end_column=1)

    wb.save(output_path)
| 233 | + |
| 234 | + |
| 235 | +# Sample command: |
| 236 | +# python -m silnlp.nmt.exp_summary Catapult_Reloaded_Confidences |
| 237 | +# --trained-books MRK --target-book MAT --metrics chrf3 confidence --key-word conf --baseline Catapult_Reloaded/2nd_book/MRK |
def main() -> None:
    """Command-line entry point: summarize experiment results into one xlsx file.

    Parses arguments, reads multi-group results (--exp) and/or a single
    baseline experiment (--baseline), flattens them, and writes the summary
    spreadsheet under <result_dir>/a_result_folder/.
    """
    parser = argparse.ArgumentParser(
        description="Pulling results from a single experiment and/or multiple experiment groups. "
        "A valid experiment should have the following format: "
        "baseline/lang_pair/exp_group/diff_predictions or baseline/lang_pair/diff_predictions for a single experiment "
        "or "
        "exp/lang_pair/exp_groups/diff_predictions for multiple experiment groups "
        "More information in --exp and --baseline. "
        "Use --exp for multiple experiment groups and --baseline for a single experiment. "
        "At least one --exp or --baseline needs to be specified. "
    )
    parser.add_argument(
        "--exp",
        type=str,
        help="Experiment folder with progression results. "
        "A valid experiment groups should have the following format: "
        "exp/lang_pair/exp_groups/diff_predictions "
        "where there should be at least one exp_groups that naming in the following format: "
        "*book*+*book*_*key-word*_order_*number*_ch "
        "where *book*+*book*... are the combination of all --trained-books with the last one being --target-book. "
        "More information in --key-word. ",
    )
    parser.add_argument(
        "--trained-books", nargs="*", required=True, type=str.upper, help="Books that are trained in the exp "
    )
    parser.add_argument("--target-book", required=True, type=str.upper, help="Book that is going to be analyzed ")
    parser.add_argument(
        "--metrics",
        nargs="*",
        metavar="metrics",
        default=["chrf3", "confidence"],
        type=str.lower,
        help="Metrics that will be analyzed with ",
    )
    parser.add_argument(
        "--key-word",
        type=str,
        default="conf",
        help="Key word in the filename for the exp group to distinguish between the experiment purpose. "
        "For example, in LUK+ACT_conf_order_12_ch, the key-word should be conf. "
        "Another example, in LUK+ACT_standard_order_12_ch, the key-word should be standard. ",
    )
    parser.add_argument(
        "--baseline",
        type=str,
        help="A non-progression folder for a single experiment. "
        "A valid single experiment should have the following format: "
        "baseline/lang_pair/exp_group/diff_predictions where exp_group will be in the following format: "
        "*book*+*book*... as the combination of all --trained-books. "
        "or "
        "baseline/lang_pair/diff_predictions "
        "where the information of --trained-books should have already been indicated in the baseline name. ",
    )
    args = parser.parse_args()

    if not (args.exp or args.baseline):
        parser.error("At least one --exp or --baseline needs to be specified. ")

    trained_books = args.trained_books
    target_book = args.target_book
    all_books = trained_books + [target_book]
    metrics = args.metrics
    key_word = args.key_word

    chap_num = 0

    multi_group_exp_name = args.exp
    multi_group_exp_dir = get_mt_exp_dir(multi_group_exp_name) if multi_group_exp_name else None

    single_group_exp_name = args.baseline
    single_group_exp_dir = get_mt_exp_dir(single_group_exp_name) if single_group_exp_name else None

    # Prefer the multi-group experiment dir as the output location.
    result_file_name = "+".join(all_books)
    result_dir = multi_group_exp_dir if multi_group_exp_dir else single_group_exp_dir
    os.makedirs(os.path.join(result_dir, "a_result_folder"), exist_ok=True)
    output_path = os.path.join(result_dir, "a_result_folder", f"{result_file_name}.xlsx")

    data = {}
    chapter_groups = set()
    if multi_group_exp_dir:
        data, chapter_groups, chap_num = read_group_results(
            multi_group_exp_dir, target_book, all_books, metrics, key_word, chap_num
        )
    # Sorted order fixes the column layout of the summary sheet.
    chapter_groups = sorted(chapter_groups)

    baseline_data = {}
    if single_group_exp_dir:
        # Compile once; the same pattern is reused for every directory entry.
        lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
        for lang_pair in os.listdir(single_group_exp_dir):
            if not lang_pattern.match(lang_pair):
                continue

            # diff_predictions may live directly under the lang pair, or one
            # level deeper in a folder named after the trained books.
            baseline_path = os.path.join(single_group_exp_dir, lang_pair)
            baseline_diff_pred = glob.glob(os.path.join(baseline_path, "diff_predictions*"))
            if baseline_diff_pred:
                baseline_data[lang_pair], chap_num = extract_diff_pred_data(
                    baseline_diff_pred[0], metrics, target_book, chap_num
                )
            else:
                print(f"Checking experiments under {baseline_path}...")
                sub_baseline_path = os.path.join(baseline_path, "+".join(trained_books))
                baseline_diff_pred = glob.glob(os.path.join(sub_baseline_path, "diff_predictions*"))
                if baseline_diff_pred:
                    baseline_data[lang_pair], chap_num = extract_diff_pred_data(
                        baseline_diff_pred[0], metrics, target_book, chap_num
                    )
                else:
                    print(f"Baseline experiment has no diff_predictions file in {sub_baseline_path}")

    print("Writing data...")
    rows = flatten_dict(data, chapter_groups, metrics, chap_num, baseline=baseline_data)
    create_xlsx(rows, chapter_groups, output_path, metrics, chap_num)
    print(f"Result is in {output_path}")
| 351 | + |
| 352 | + |
# Allow running directly as a script or via `python -m silnlp.nmt.exp_summary`.
if __name__ == "__main__":
    main()
0 commit comments