From 81c48d5e2918f47baa47e31d9122e26029c530bb Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Fri, 20 Jun 2025 04:40:09 +0000 Subject: [PATCH 01/10] Add support for multiple postprocessing requests --- silnlp/common/postprocess_draft.py | 108 ++++++++++++++------ silnlp/common/translator.py | 153 ++++++++++++++++++----------- silnlp/nmt/experiment.py | 8 +- silnlp/nmt/translate.py | 44 +++------ 4 files changed, 185 insertions(+), 128 deletions(-) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index 2bd847e1..c708800b 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -40,7 +40,7 @@ def get_paths_from_exp(config: Config) -> Tuple[Path, Path]: translate_config = yaml.safe_load(file)["translate"][0] src_project = translate_config.get("src_project", next(iter(config.src_projects))) books = translate_config["books"] - book = books[0] if isinstance(books, list) else books.split(";")[0] # TODO: handle partial book translation + book = books[0][:3] if isinstance(books, list) else books.split(";")[0][:3] book_num = book_id_to_number(book) ckpt = translate_config.get("checkpoint", "last") @@ -89,9 +89,7 @@ def get_sentences( def main() -> None: - parser = argparse.ArgumentParser( - description="Applies draft postprocessing steps to a draft. Can be used with no postprocessing options to create a base draft." - ) + parser = argparse.ArgumentParser(description="Applies draft postprocessing steps to a draft.") parser.add_argument( "--experiment", default=None, @@ -173,6 +171,27 @@ def main() -> None: src_path = Path(args.source.replace("\\", "/")) draft_path = Path(args.draft.replace("\\", "/")) + # If no postprocessing options are used, use any postprocessing requests in the experiment's translate config + if args.include_paragraph_markers or args.include_style_markers or args.include_embeds: + postprocess_configs = [ + { + "include_paragraph_markers": args.include_paragraph_markers, + "include_style_markers": args.include_style_markers, + "include_embeds": args.include_embeds, + } + ] + else: + if args.experiment: + LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.") + with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: + postprocess_configs = yaml.safe_load(file).get("postprocess", []) + if len(postprocess_configs) == 0: + LOGGER.info("No postprocessing requests found.") + exit() + else: + LOGGER.info("Please use at least one postprocessing option.") + exit() + if str(src_path).startswith(str(get_project_dir(""))): settings = FileParatextProjectSettingsParser(src_path.parent).parse() stylesheet = settings.stylesheet @@ -198,36 +217,61 @@ def main() -> None: f"'source' and 'draft' must have the exact same USFM structure. Mismatched ref: {src_ref} {draft_ref}" ) - paragraph_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE if args.include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP - ) - style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_style_markers else UpdateUsfmMarkerBehavior.STRIP - embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_embeds else UpdateUsfmMarkerBehavior.STRIP - - update_block_handlers = [] - if args.include_paragraph_markers or args.include_style_markers: - update_block_handlers.append(construct_place_markers_handler(src_refs, src_sents, draft_sents)) - - with src_path.open(encoding=encoding) as f: - usfm = f.read() - handler = UpdateUsfmParserHandler( - rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)], - id_text=book, - text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, - ) - parse_usfm(usfm, handler) - usfm_out = handler.get_usfm() + if any( + ppc.get("include_paragraph_markers", False) or ppc.get("include_style_markers", False) + for ppc in postprocess_configs + ): + place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents) + + for postprocess_config in postprocess_configs: + update_block_handlers = [] + if postprocess_config.get("include_paragraph_markers", False) or postprocess_config.get( + "include_style_markers", False + ): + update_block_handlers.append(place_markers_handler) + + paragraph_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config.get("include_paragraph_markers", False) + else UpdateUsfmMarkerBehavior.STRIP + ) + style_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config.get("include_style_markers", False) + else UpdateUsfmMarkerBehavior.STRIP + ) + embed_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config.get("include_embeds", False) + else UpdateUsfmMarkerBehavior.STRIP + ) + marker_placement_suffix = ( + "_" + + ("p" if postprocess_config.get("include_paragraph_markers", False) else "") + + ("s" if postprocess_config.get("include_style_markers", False) else "") + + ("e" if postprocess_config.get("include_embeds", False) else "") + ) + + with src_path.open(encoding=encoding) as f: + usfm = f.read() + handler = UpdateUsfmParserHandler( + rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)], + id_text=book, + text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, + paragraph_behavior=paragraph_behavior, + embed_behavior=embed_behavior, + style_behavior=style_behavior, + update_block_handlers=update_block_handlers, + ) + parse_usfm(usfm, handler) + usfm_out = handler.get_usfm() - usfm_out = insert_draft_remarks(usfm_out, draft_remarks) + usfm_out = insert_draft_remarks(usfm_out, draft_remarks) - out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent - out_path = out_dir / f"{draft_path.stem}_postprocessed{draft_path.suffix}" - with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: - f.write(usfm_out) + out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent + out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}" + with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: + f.write(usfm_out) if __name__ == "__main__": diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 71acfe5a..9f137352 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -5,7 +5,7 @@ from itertools import groupby from math import exp from pathlib import Path -from typing import Iterable, List, Optional +from typing import Dict, Iterable, List, Optional import docx import nltk @@ -137,9 +137,7 @@ def translate_book( produce_multiple_translations: bool = False, chapters: List[int] = [], trg_project: Optional[str] = None, - include_paragraph_markers: bool = False, - include_style_markers: bool = False, - include_embeds: bool = False, + postprocess_configs: List[Dict[str, bool]] = [], experiment_ckpt_str: str = "", ) -> None: book_path = get_book_path(src_project, book) @@ -156,9 +154,7 @@ def translate_book( produce_multiple_translations, chapters, trg_project, - include_paragraph_markers, - include_style_markers, - include_embeds, + postprocess_configs, experiment_ckpt_str, ) @@ -171,9 +167,7 @@ def translate_usfm( produce_multiple_translations: bool = False, chapters: List[int] = [], trg_project: Optional[str] = None, - include_paragraph_markers: bool = False, - include_style_markers: bool = False, - include_embeds: bool = False, + postprocess_configs: List[Dict[str, bool]] = [], experiment_ckpt_str: str = "", ) -> None: # Create UsfmFileText object for source @@ -226,72 +220,111 @@ def translate_usfm( vrefs.insert(idx, vref) output.insert(idx, [None, None, None, None]) - # Update behaviors - text_behavior = ( - UpdateUsfmTextBehavior.PREFER_NEW if trg_project is not None else UpdateUsfmTextBehavior.STRIP_EXISTING - ) - paragraph_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE if include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP - ) - style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_style_markers else UpdateUsfmMarkerBehavior.STRIP - embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_embeds else UpdateUsfmMarkerBehavior.STRIP + # Base draft + postprocess_configs = [ + {"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False} + ] + postprocess_configs draft_set: DraftGroup = DraftGroup(translations) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): rows = [([ref], translation) for ref, translation in zip(vrefs, translated_draft)] - update_block_handlers = [] - if include_paragraph_markers or include_style_markers: - update_block_handlers.append(construct_place_markers_handler(vrefs, sentences, translated_draft)) - - # Insert translation into the USFM structure of an existing project - # If the target project is not the same as the translated file's original project, - # no verses outside of the ones translated will be overwritten - if trg_project is not None or src_from_project: - dest_updater = FileParatextProjectTextUpdater( - get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name) + if any( + ppc.get("include_paragraph_markers", False) or ppc.get("include_style_markers", False) + for ppc in postprocess_configs + ): + place_markers_handler = construct_place_markers_handler(vrefs, sentences, translated_draft) + + for postprocess_config in postprocess_configs: + # Update behaviors + text_behavior = ( + UpdateUsfmTextBehavior.PREFER_NEW + if trg_project is not None + else UpdateUsfmTextBehavior.STRIP_EXISTING ) - usfm_out = dest_updater.update_usfm( - book_id=src_file_text.id, - rows=rows, - text_behavior=text_behavior, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, + paragraph_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config.get("include_paragraph_markers", False) + else UpdateUsfmMarkerBehavior.STRIP ) - - if usfm_out is None: - raise FileNotFoundError(f"Book {src_file_text.id} does not exist in target project {trg_project}") - else: # Slightly more manual version for updating an individual file - with open(src_file_path, encoding="utf-8-sig") as f: - usfm = f.read() - handler = UpdateUsfmParserHandler( - rows=rows, - id_text=vrefs[0].book, - text_behavior=text_behavior, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, + style_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config.get("include_style_markers", False) + else UpdateUsfmMarkerBehavior.STRIP + ) + embed_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config.get("include_embeds", False) + else UpdateUsfmMarkerBehavior.STRIP + ) + marker_placement_suffix = ( + "_" + + ("p" if postprocess_config.get("include_paragraph_markers", False) else "") + + ("s" if postprocess_config.get("include_style_markers", False) else "") + + ("e" if postprocess_config.get("include_embeds", False) else "") ) - parse_usfm(usfm, handler) - usfm_out = handler.get_usfm() + marker_placement_suffix = "" if len(marker_placement_suffix) == 1 else marker_placement_suffix + + update_block_handlers = [] + if postprocess_config.get("include_paragraph_markers", False) or postprocess_config.get( + "include_style_markers", False + ): + update_block_handlers.append(place_markers_handler) + + # Insert translation into the USFM structure of an existing project + # If the target project is not the same as the translated file's original project, + # no verses outside of the ones translated will be overwritten + if trg_project is not None or src_from_project: + dest_updater = FileParatextProjectTextUpdater( + get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name) + ) + usfm_out = dest_updater.update_usfm( + book_id=src_file_text.id, + rows=rows, + text_behavior=text_behavior, + paragraph_behavior=paragraph_behavior, + embed_behavior=embed_behavior, + style_behavior=style_behavior, + update_block_handlers=update_block_handlers, + ) + + if usfm_out is None: + raise FileNotFoundError( + f"Book {src_file_text.id} does not exist in target project {trg_project}" + ) + else: # Slightly more manual version for updating an individual file + with open(src_file_path, encoding="utf-8-sig") as f: + usfm = f.read() + handler = UpdateUsfmParserHandler( + rows=rows, + id_text=vrefs[0].book, + text_behavior=text_behavior, + paragraph_behavior=paragraph_behavior, + embed_behavior=embed_behavior, + style_behavior=style_behavior, + update_block_handlers=update_block_handlers, + ) + parse_usfm(usfm, handler) + usfm_out = handler.get_usfm() + + # Insert draft remark and write to output path + description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}" + usfm_out = insert_draft_remark(usfm_out, vrefs[0].book, description, experiment_ckpt_str) + trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + marker_placement_suffix) + if produce_multiple_translations: + trg_draft_file_path = trg_draft_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}") + with trg_draft_file_path.open( + "w", encoding=src_settings.encoding if src_from_project else "utf-8" + ) as f: + f.write(usfm_out) - # Insert draft remark and write to output path - description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}" - usfm_out = insert_draft_remark(usfm_out, vrefs[0].book, description, experiment_ckpt_str) confidence_scores_suffix = ".confidences.tsv" if produce_multiple_translations: - trg_draft_file_path = trg_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}") confidences_path = trg_file_path.with_suffix( f".{draft_index}{trg_file_path.suffix}{confidence_scores_suffix}" ) else: - trg_draft_file_path = trg_file_path confidences_path = trg_file_path.with_suffix(f"{trg_file_path.suffix}{confidence_scores_suffix}") - with trg_draft_file_path.open("w", encoding=src_settings.encoding if src_from_project else "utf-8") as f: - f.write(usfm_out) with confidences_path.open("w", encoding="utf-8", newline="\n") as confidences_file: confidences_file.write("\t".join(["VRef"] + [f"Token {i}" for i in range(200)]) + "\n") confidences_file.write("\t".join(["Sequence Score"] + [f"Token Score {i}" for i in range(200)]) + "\n") diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py index 8e44598d..9c6d3d6b 100644 --- a/silnlp/nmt/experiment.py +++ b/silnlp/nmt/experiment.py @@ -95,9 +95,7 @@ def translate(self): config.get("trg_project"), config.get("trg_iso"), self.produce_multiple_translations, - config.get("include_paragraph_markers", False) or config.get("preserve_usfm_markers", False), - config.get("include_style_markers", False) or config.get("preserve_usfm_markers", False), - config.get("include_embeds", False) or config.get("include_inline_elements", False), + translate_configs.get("postprocess", []), ) elif config.get("src_prefix"): translator.translate_text_files( @@ -116,9 +114,7 @@ def translate(self): config.get("src_iso"), config.get("trg_iso"), self.produce_multiple_translations, - config.get("include_paragraph_markers", False) or config.get("preserve_usfm_markers", False), - config.get("include_style_markers", False) or config.get("preserve_usfm_markers", False), - config.get("include_embeds", False) or config.get("include_inline_elements", False), + translate_configs.get("postprocess", []), ) else: raise RuntimeError("A Scripture book, file, or file prefix must be specified for translation.") diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index e066dd21..f14924c7 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -4,7 +4,7 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union from machine.scripture import VerseRef, book_number_to_id, get_chapters @@ -54,9 +54,7 @@ def translate_books( trg_project: Optional[str], trg_iso: Optional[str], produce_multiple_translations: bool = False, - include_paragraph_markers: bool = False, - include_style_markers: bool = False, - include_embeds: bool = False, + postprocess_configs: List[Dict[str, bool]] = [], ): book_nums = get_chapters(books) translator, config, step_str = self._init_translation_task( @@ -112,9 +110,7 @@ def translate_books( produce_multiple_translations, chapters, trg_project, - include_paragraph_markers, - include_style_markers, - include_embeds, + postprocess_configs, experiment_ckpt_str, ) except Exception as e: @@ -178,9 +174,7 @@ def translate_files( src_iso: Optional[str], trg_iso: Optional[str], produce_multiple_translations: bool = False, - include_paragraph_markers: bool = False, - include_style_markers: bool = False, - include_embeds: bool = False, + postprocess_configs: List[Dict[str, bool]] = [], ) -> None: translator, config, step_str = self._init_translation_task( experiment_suffix=f"_{self.checkpoint}_{os.path.basename(src)}", @@ -248,9 +242,7 @@ def translate_files( src_iso, trg_iso, produce_multiple_translations, - include_paragraph_markers=include_paragraph_markers, - include_style_markers=include_style_markers, - include_embeds=include_embeds, + postprocess_configs, experiment_ckpt_str=experiment_ckpt_str, ) @@ -383,12 +375,13 @@ def main() -> None: name=args.experiment, checkpoint=args.checkpoint, clearml_queue=args.clearml_queue, commit=args.commit ) - # For backwards compatibility - if args.preserve_usfm_markers: - args.include_paragraph_markers = True - args.include_style_markers = True - if args.include_inline_elements: - args.include_embeds = True + postprocess_configs = [ + { + "include_paragraph_markers": args.include_paragraph_markers or args.preserve_usfm_markers, + "include_style_markers": args.include_style_markers or args.preserve_usfm_markers, + "include_embeds": args.include_embeds or args.include_inline_elements, + } + ] if len(args.books) > 0: if args.debug: @@ -400,9 +393,7 @@ def main() -> None: args.trg_project, args.trg_iso, args.multiple_translations, - args.include_paragraph_markers, - args.include_style_markers, - args.include_embeds, + postprocess_configs, ) elif args.src_prefix is not None: if args.debug: @@ -428,14 +419,7 @@ def main() -> None: ) exit() translator.translate_files( - args.src, - args.trg, - args.src_iso, - args.trg_iso, - args.multiple_translations, - args.include_paragraph_markers, - args.include_style_markers, - args.include_embeds, + args.src, args.trg, args.src_iso, args.trg_iso, args.multiple_translations, postprocess_configs ) else: raise RuntimeError("A Scripture book, file, or file prefix must be specified.") From d016313194df8b2201dca6f5787e96137eadb1ab Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Tue, 24 Jun 2025 16:33:28 +0000 Subject: [PATCH 02/10] Use defaultdicts for postprocessing configs --- silnlp/common/postprocess_draft.py | 25 +++++++++++-------------- silnlp/common/translator.py | 29 ++++++++++++----------------- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index c708800b..ca9a0409 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -1,6 +1,7 @@ import argparse import logging import re +from collections import defaultdict from pathlib import Path from typing import List, Tuple @@ -25,7 +26,7 @@ from ..nmt.hugging_face_config import get_best_checkpoint from .paratext import book_file_name_digits, get_book_path, get_project_dir from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler -from .utils import get_mt_exp_dir +from .utils import get_mt_exp_dir, merge_dict LOGGER = logging.getLogger(__package__ + ".postprocess_draft") @@ -185,6 +186,7 @@ def main() -> None: LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.") with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: postprocess_configs = yaml.safe_load(file).get("postprocess", []) + postprocess_configs = [merge_dict(defaultdict(lambda: False), ppc) for ppc in postprocess_configs] if len(postprocess_configs) == 0: LOGGER.info("No postprocessing requests found.") exit() @@ -217,39 +219,34 @@ def main() -> None: f"'source' and 'draft' must have the exact same USFM structure. Mismatched ref: {src_ref} {draft_ref}" ) - if any( - ppc.get("include_paragraph_markers", False) or ppc.get("include_style_markers", False) - for ppc in postprocess_configs - ): + if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents) for postprocess_config in postprocess_configs: update_block_handlers = [] - if postprocess_config.get("include_paragraph_markers", False) or postprocess_config.get( - "include_style_markers", False - ): + if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: update_block_handlers.append(place_markers_handler) paragraph_behavior = ( UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config.get("include_paragraph_markers", False) + if postprocess_config["include_paragraph_markers"] else UpdateUsfmMarkerBehavior.STRIP ) style_behavior = ( UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config.get("include_style_markers", False) + if postprocess_config["include_style_markers"] else UpdateUsfmMarkerBehavior.STRIP ) embed_behavior = ( UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config.get("include_embeds", False) + if postprocess_config["include_embeds"] else UpdateUsfmMarkerBehavior.STRIP ) marker_placement_suffix = ( "_" - + ("p" if postprocess_config.get("include_paragraph_markers", False) else "") - + ("s" if postprocess_config.get("include_style_markers", False) else "") - + ("e" if postprocess_config.get("include_embeds", False) else "") + + ("p" if postprocess_config["include_paragraph_markers"] else "") + + ("s" if postprocess_config["include_style_markers"] else "") + + ("e" if postprocess_config["include_embeds"] else "") ) with src_path.open(encoding=encoding) as f: diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 9f137352..cb8d3467 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -1,6 +1,7 @@ import logging import re from abc import ABC, abstractmethod +from collections import defaultdict from datetime import date from itertools import groupby from math import exp @@ -26,6 +27,7 @@ from .corpus import load_corpus, write_corpus from .paratext import get_book_path, get_iso, get_project_dir from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler +from .utils import merge_dict LOGGER = logging.getLogger(__package__ + ".translate") nltk.download("punkt") @@ -220,19 +222,14 @@ def translate_usfm( vrefs.insert(idx, vref) output.insert(idx, [None, None, None, None]) - # Base draft - postprocess_configs = [ - {"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False} - ] + postprocess_configs + # Prepare configs: add base draft and default value + postprocess_configs = [merge_dict(defaultdict(lambda: False), ppc) for ppc in [{}] + postprocess_configs] draft_set: DraftGroup = DraftGroup(translations) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): rows = [([ref], translation) for ref, translation in zip(vrefs, translated_draft)] - if any( - ppc.get("include_paragraph_markers", False) or ppc.get("include_style_markers", False) - for ppc in postprocess_configs - ): + if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): place_markers_handler = construct_place_markers_handler(vrefs, sentences, translated_draft) for postprocess_config in postprocess_configs: @@ -244,31 +241,29 @@ def translate_usfm( ) paragraph_behavior = ( UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config.get("include_paragraph_markers", False) + if postprocess_config["include_paragraph_markers"] else UpdateUsfmMarkerBehavior.STRIP ) style_behavior = ( UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config.get("include_style_markers", False) + if postprocess_config["include_style_markers"] else UpdateUsfmMarkerBehavior.STRIP ) embed_behavior = ( UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config.get("include_embeds", False) + if postprocess_config["include_embeds"] else UpdateUsfmMarkerBehavior.STRIP ) marker_placement_suffix = ( "_" - + ("p" if postprocess_config.get("include_paragraph_markers", False) else "") - + ("s" if postprocess_config.get("include_style_markers", False) else "") - + ("e" if postprocess_config.get("include_embeds", False) else "") + + ("p" if postprocess_config["include_paragraph_markers"] else "") + + ("s" if postprocess_config["include_style_markers"] else "") + + ("e" if postprocess_config["include_embeds"] else "") ) marker_placement_suffix = "" if len(marker_placement_suffix) == 1 else marker_placement_suffix update_block_handlers = [] - if postprocess_config.get("include_paragraph_markers", False) or postprocess_config.get( - "include_style_markers", False - ): + if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: update_block_handlers.append(place_markers_handler) # Insert translation into the USFM structure of an existing project From 0cb4595b220ad53796c4ce5e8cacd17cd083d167 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Tue, 24 Jun 2025 18:47:54 +0000 Subject: [PATCH 03/10] Do postprocessing for all books of all translate requests --- silnlp/common/postprocess_draft.py | 200 +++++++++++++++-------------- 1 file changed, 104 insertions(+), 96 deletions(-) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index ca9a0409..a124902b 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -17,7 +17,7 @@ UsfmTextType, parse_usfm, ) -from machine.scripture import book_id_to_number +from machine.scripture import book_number_to_id, get_chapters from transformers.trainer_utils import get_last_checkpoint from ..nmt.clearml_connection import SILClearML @@ -31,31 +31,35 @@ LOGGER = logging.getLogger(__package__ + ".postprocess_draft") -# NOTE: only using first book of first translate request for now -def get_paths_from_exp(config: Config) -> Tuple[Path, Path]: - # TODO: default to first draft in the infer folder +def get_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path]]: if not (config.exp_dir / "translate_config.yml").exists(): raise ValueError("Experiment translate_config.yml not found. Please use --source and --draft options instead.") - with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: - translate_config = yaml.safe_load(file)["translate"][0] - src_project = translate_config.get("src_project", next(iter(config.src_projects))) - books = translate_config["books"] - book = books[0][:3] if isinstance(books, list) else books.split(";")[0][:3] - book_num = book_id_to_number(book) - - ckpt = translate_config.get("checkpoint", "last") - if ckpt == "best": - step_str = get_best_checkpoint(config.model_dir).name[11:] - elif ckpt == "last": - step_str = Path(get_last_checkpoint(config.model_dir)).name[11:] - else: - step_str = str(ckpt) + translate_requests = yaml.safe_load(file).get("translate", []) - return ( - get_book_path(src_project, book), - config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM", - ) + src_paths = [] + draft_paths = [] + for translate_request in translate_requests: + src_project = translate_request.get("src_project", next(iter(config.src_projects))) + + ckpt = translate_request.get("checkpoint", "last") + if ckpt == "best": + step_str = get_best_checkpoint(config.model_dir).name[11:] + elif ckpt == "last": + step_str = Path(get_last_checkpoint(config.model_dir)).name[11:] + else: + step_str = str(ckpt) + + book_nums = get_chapters(translate_request.get("books", [])).keys() + for book_num in book_nums: + book = book_number_to_id(book_num) + + src_paths.append(get_book_path(src_project, book)) + draft_paths.append( + config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM" + ) + + return src_paths, draft_paths def insert_draft_remarks(usfm: str, remarks: List[str]) -> str: @@ -165,12 +169,12 @@ def main() -> None: config = yaml.safe_load(file) config = create_config(exp_dir, config) - src_path, draft_path = get_paths_from_exp(config) + src_paths, draft_paths = get_paths_from_exp(config) elif args.clearml_queue is not None: raise ValueError("Must use --experiment option to use ClearML.") else: - src_path = Path(args.source.replace("\\", "/")) - draft_path = Path(args.draft.replace("\\", "/")) + src_paths = [Path(args.source.replace("\\", "/"))] + draft_paths = [Path(args.draft.replace("\\", "/"))] # If no postprocessing options are used, use any postprocessing requests in the experiment's translate config if args.include_paragraph_markers or args.include_style_markers or args.include_embeds: @@ -194,81 +198,85 @@ def main() -> None: LOGGER.info("Please use at least one postprocessing option.") exit() - if str(src_path).startswith(str(get_project_dir(""))): - settings = FileParatextProjectSettingsParser(src_path.parent).parse() - stylesheet = settings.stylesheet - encoding = settings.encoding - book = settings.get_book_id(src_path.name) - else: - stylesheet = UsfmStylesheet("usfm.sty") - encoding = "utf-8-sig" - book = args.book - if book is None: - raise ValueError( - "--book argument must be passed if the source file is not in a Paratext project directory." - ) + for src_path, draft_path in zip(src_paths, draft_paths): + if str(src_path).startswith(str(get_project_dir(""))): + settings = FileParatextProjectSettingsParser(src_path.parent).parse() + stylesheet = settings.stylesheet + encoding = settings.encoding + book = settings.get_book_id(src_path.name) + else: + stylesheet = UsfmStylesheet("usfm.sty") + encoding = "utf-8-sig" + book = args.book + if book is None: + raise ValueError( + "--book argument must be passed if the source file is not in a Paratext project directory." + ) + + src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book) + draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book) + + if len(src_refs) != len(draft_refs): + LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references") + continue + for src_ref, draft_ref in zip(src_refs, draft_refs): + if src_ref.to_relaxed() != draft_ref.to_relaxed(): + LOGGER.warning( + f"Can't process {src_path} and {draft_path}: Mismatched ref: {src_ref} {draft_ref}. Files must have the exact same USFM structure" + ) + continue + + if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): + place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents) - src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book) - draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book) + for postprocess_config in postprocess_configs: + update_block_handlers = [] + if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: + update_block_handlers.append(place_markers_handler) - if len(src_refs) != len(draft_refs): - raise ValueError("Different number of verses/references between source and draft.") - for src_ref, draft_ref in zip(src_refs, draft_refs): - if src_ref.to_relaxed() != draft_ref.to_relaxed(): - raise ValueError( - f"'source' and 'draft' must have the exact same USFM structure. Mismatched ref: {src_ref} {draft_ref}" + paragraph_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config["include_paragraph_markers"] + else UpdateUsfmMarkerBehavior.STRIP + ) + style_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config["include_style_markers"] + else UpdateUsfmMarkerBehavior.STRIP ) + embed_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config["include_embeds"] + else UpdateUsfmMarkerBehavior.STRIP + ) + marker_placement_suffix = ( + "_" + + ("p" if postprocess_config["include_paragraph_markers"] else "") + + ("s" if postprocess_config["include_style_markers"] else "") + + ("e" if postprocess_config["include_embeds"] else "") + ) + + with src_path.open(encoding=encoding) as f: + usfm = f.read() + handler = UpdateUsfmParserHandler( + rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)], + id_text=book, + text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, + paragraph_behavior=paragraph_behavior, + embed_behavior=embed_behavior, + style_behavior=style_behavior, + update_block_handlers=update_block_handlers, + ) + parse_usfm(usfm, handler) + usfm_out = handler.get_usfm() + + # TODO: back a level + usfm_out = insert_draft_remarks(usfm_out, draft_remarks) - if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): - place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents) - - for postprocess_config in postprocess_configs: - update_block_handlers = [] - if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: - update_block_handlers.append(place_markers_handler) - - paragraph_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_paragraph_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - style_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_style_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - embed_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_embeds"] - else UpdateUsfmMarkerBehavior.STRIP - ) - marker_placement_suffix = ( - "_" - + ("p" if postprocess_config["include_paragraph_markers"] else "") - + ("s" if postprocess_config["include_style_markers"] else "") - + ("e" if postprocess_config["include_embeds"] else "") - ) - - with src_path.open(encoding=encoding) as f: - usfm = f.read() - handler = UpdateUsfmParserHandler( - rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)], - id_text=book, - text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, - ) - parse_usfm(usfm, handler) - usfm_out = handler.get_usfm() - - usfm_out = insert_draft_remarks(usfm_out, draft_remarks) - - out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent - out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}" - with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: - f.write(usfm_out) + out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent + out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}" + with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: + f.write(usfm_out) if __name__ == "__main__": From 0b7d0c6bc54dddc179bd1fbc2f34588f4c53306f Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Tue, 24 Jun 2025 19:57:16 +0000 Subject: [PATCH 04/10] Organize into functions --- silnlp/common/postprocess_draft.py | 180 ++++++++++++++++------------- 1 file changed, 101 insertions(+), 79 deletions(-) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index a124902b..a9a06f82 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -3,7 +3,7 @@ import re from collections import defaultdict from pathlib import Path -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple import yaml from machine.corpora import ( @@ -15,6 +15,7 @@ UsfmFileText, UsfmStylesheet, UsfmTextType, + UsfmUpdateBlockHandler, parse_usfm, ) from machine.scripture import book_number_to_id, get_chapters @@ -93,6 +94,102 @@ def get_sentences( return sents, refs, draft_remarks +def postprocess_drafts( + src_path: Path, + draft_path: Path, + postprocess_configs: List[Dict[str, bool]], + book: Optional[str] = None, + out_dir: Optional[Path] = None, +) -> None: + if str(src_path).startswith(str(get_project_dir(""))): + settings = FileParatextProjectSettingsParser(src_path.parent).parse() + stylesheet = settings.stylesheet + encoding = settings.encoding + book = settings.get_book_id(src_path.name) + else: + stylesheet = UsfmStylesheet("usfm.sty") + encoding = "utf-8-sig" + if book is None: + raise ValueError( + "--book argument must be passed if the source file is not in a Paratext project directory." + ) + + src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book) + draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book) + + # Verify reference parity + if len(src_refs) != len(draft_refs): + LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references") + return + for src_ref, draft_ref in zip(src_refs, draft_refs): + if src_ref.to_relaxed() != draft_ref.to_relaxed(): + LOGGER.warning( + f"Can't process {src_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure" + ) + return + + # Initialize UsfmUpdateBlockHandlers as necessary + if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): + place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents) + + with src_path.open(encoding=encoding) as f: + usfm = f.read() + rows = [([ref], sent) for ref, sent in zip(src_refs, draft_sents)] + + for postprocess_config in postprocess_configs: + update_block_handlers = [] + if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: + update_block_handlers.append(place_markers_handler) + + usfm_out = update_draft(usfm, rows, postprocess_config, update_block_handlers) + usfm_out = insert_draft_remarks(usfm_out, draft_remarks) + + marker_placement_suffix = ( + "_" + + ("p" if postprocess_config["include_paragraph_markers"] else "") + + ("s" if postprocess_config["include_style_markers"] else "") + + ("e" if postprocess_config["include_embeds"] else "") + ) + if not out_dir: + out_dir = draft_path.parent + out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}" + with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: + f.write(usfm_out) + + +def update_draft( + usfm: str, + rows: List[Tuple[List[ScriptureRef], str]], + postprocess_config: Dict[str, bool], + update_block_handlers: List[UsfmUpdateBlockHandler] = [], +) -> str: + paragraph_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config["include_paragraph_markers"] + else UpdateUsfmMarkerBehavior.STRIP + ) + style_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE + if postprocess_config["include_style_markers"] + else UpdateUsfmMarkerBehavior.STRIP + ) + embed_behavior = ( + UpdateUsfmMarkerBehavior.PRESERVE if postprocess_config["include_embeds"] else UpdateUsfmMarkerBehavior.STRIP + ) + + handler = UpdateUsfmParserHandler( + rows=rows, + id_text=rows[0][0][0].book, + text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, + paragraph_behavior=paragraph_behavior, + embed_behavior=embed_behavior, + style_behavior=style_behavior, + update_block_handlers=update_block_handlers, + ) + parse_usfm(usfm, handler) + return handler.get_usfm() + + def main() -> None: parser = argparse.ArgumentParser(description="Applies draft postprocessing steps to a draft.") parser.add_argument( @@ -198,85 +295,10 @@ def main() -> None: LOGGER.info("Please use at least one postprocessing option.") exit() + if args.output_folder: + args.output_folder = Path(args.output_folder.replace("\\", "/")) for src_path, draft_path in zip(src_paths, draft_paths): - if str(src_path).startswith(str(get_project_dir(""))): - settings = FileParatextProjectSettingsParser(src_path.parent).parse() - stylesheet = settings.stylesheet - encoding = settings.encoding - book = settings.get_book_id(src_path.name) - else: - stylesheet = UsfmStylesheet("usfm.sty") - encoding = "utf-8-sig" - book = args.book - if book is None: - raise ValueError( - "--book argument must be passed if the source file is not in a Paratext project directory." - ) - - src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book) - draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book) - - if len(src_refs) != len(draft_refs): - LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references") - continue - for src_ref, draft_ref in zip(src_refs, draft_refs): - if src_ref.to_relaxed() != draft_ref.to_relaxed(): - LOGGER.warning( - f"Can't process {src_path} and {draft_path}: Mismatched ref: {src_ref} {draft_ref}. Files must have the exact same USFM structure" - ) - continue - - if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): - place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents) - - for postprocess_config in postprocess_configs: - update_block_handlers = [] - if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: - update_block_handlers.append(place_markers_handler) - - paragraph_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_paragraph_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - style_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_style_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - embed_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_embeds"] - else UpdateUsfmMarkerBehavior.STRIP - ) - marker_placement_suffix = ( - "_" - + ("p" if postprocess_config["include_paragraph_markers"] else "") - + ("s" if postprocess_config["include_style_markers"] else "") - + ("e" if postprocess_config["include_embeds"] else "") - ) - - with src_path.open(encoding=encoding) as f: - usfm = f.read() - handler = UpdateUsfmParserHandler( - rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)], - id_text=book, - text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, - ) - parse_usfm(usfm, handler) - usfm_out = handler.get_usfm() - - # TODO: back a level - usfm_out = insert_draft_remarks(usfm_out, draft_remarks) - - out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent - out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}" - with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: - f.write(usfm_out) + postprocess_drafts(src_path, draft_path, postprocess_configs, args.book, args.output_folder) if __name__ == "__main__": From cbaecd4b08a92529496aff009a46fb0d971f3f80 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Tue, 24 Jun 2025 20:06:31 +0000 Subject: [PATCH 05/10] Add remark for the postprocessing options used --- silnlp/common/postprocess_draft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index a9a06f82..1c4d137e 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -142,7 +142,9 @@ def postprocess_drafts( update_block_handlers.append(place_markers_handler) usfm_out = update_draft(usfm, rows, postprocess_config, update_block_handlers) - usfm_out = insert_draft_remarks(usfm_out, draft_remarks) + + postprocess_remark = f"Post-processing options used: {' '.join(opt for opt in postprocess_config.keys() if postprocess_config[opt])}" + usfm_out = insert_draft_remarks(usfm_out, draft_remarks + [postprocess_remark]) marker_placement_suffix = ( "_" From ce324bb18d6589883bf58c08e98dd1f08bd604f9 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Fri, 27 Jun 2025 08:43:05 +0000 Subject: [PATCH 06/10] Refactor postprocessing functionality --- silnlp/common/postprocess_draft.py | 209 +++-------------------------- silnlp/common/postprocess_utils.py | 111 +++++++++++++++ silnlp/common/translator.py | 99 +++++--------- silnlp/common/usfm_preservation.py | 48 ------- silnlp/nmt/experiment.py | 7 +- silnlp/nmt/postprocess.py | 201 +++++++++++++++++++++++++++ silnlp/nmt/translate.py | 16 ++- 7 files changed, 376 insertions(+), 315 deletions(-) create mode 100644 silnlp/common/postprocess_utils.py delete mode 100644 silnlp/common/usfm_preservation.py create mode 100644 silnlp/nmt/postprocess.py diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index 1c4d137e..5f1c381f 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -1,197 +1,19 @@ import argparse import logging -import re -from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple import yaml -from machine.corpora import ( - FileParatextProjectSettingsParser, - ScriptureRef, - UpdateUsfmMarkerBehavior, - UpdateUsfmParserHandler, - UpdateUsfmTextBehavior, - UsfmFileText, - UsfmStylesheet, - UsfmTextType, - UsfmUpdateBlockHandler, - parse_usfm, -) -from machine.scripture import book_number_to_id, get_chapters -from transformers.trainer_utils import get_last_checkpoint from ..nmt.clearml_connection import SILClearML -from ..nmt.config import Config -from ..nmt.config_utils import create_config -from ..nmt.hugging_face_config import get_best_checkpoint -from .paratext import book_file_name_digits, get_book_path, get_project_dir -from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler -from .utils import get_mt_exp_dir, merge_dict +from ..nmt.config_utils import load_config +from ..nmt.postprocess import PostprocessHandler, postprocess_draft +from .paratext import get_project_dir +from .postprocess_utils import get_draft_paths_from_exp +from .utils import get_mt_exp_dir LOGGER = logging.getLogger(__package__ + ".postprocess_draft") -def get_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path]]: - if not (config.exp_dir / "translate_config.yml").exists(): - raise ValueError("Experiment translate_config.yml not found. Please use --source and --draft options instead.") - with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: - translate_requests = yaml.safe_load(file).get("translate", []) - - src_paths = [] - draft_paths = [] - for translate_request in translate_requests: - src_project = translate_request.get("src_project", next(iter(config.src_projects))) - - ckpt = translate_request.get("checkpoint", "last") - if ckpt == "best": - step_str = get_best_checkpoint(config.model_dir).name[11:] - elif ckpt == "last": - step_str = Path(get_last_checkpoint(config.model_dir)).name[11:] - else: - step_str = str(ckpt) - - book_nums = get_chapters(translate_request.get("books", [])).keys() - for book_num in book_nums: - book = book_number_to_id(book_num) - - src_paths.append(get_book_path(src_project, book)) - draft_paths.append( - config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM" - ) - - return src_paths, draft_paths - - -def insert_draft_remarks(usfm: str, remarks: List[str]) -> str: - lines = usfm.split("\n") - remark_lines = [f"\\rem {r}" for r in remarks] - return "\n".join(lines[:1] + remark_lines + lines[1:]) - - -def get_sentences( - book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = [] -) -> Tuple[List[str], List[ScriptureRef], List[str]]: - sents = [] - refs = [] - draft_remarks = [] - for sent in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True): - marker = sent.ref.path[-1].name if len(sent.ref.path) > 0 else "" - if marker == "rem" and len(refs) == 0: # TODO: \ide and \usfm lines could potentially come before the remark(s) - draft_remarks.append(sent.text) - continue - if ( - marker in PARAGRAPH_TYPE_EMBEDS - or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT - # or len(sent.text.strip()) == 0 - or (len(chapters) > 0 and sent.ref.chapter_num not in chapters) - ): - continue - - sents.append(re.sub(" +", " ", sent.text.strip())) - refs.append(sent.ref) - - return sents, refs, draft_remarks - - -def postprocess_drafts( - src_path: Path, - draft_path: Path, - postprocess_configs: List[Dict[str, bool]], - book: Optional[str] = None, - out_dir: Optional[Path] = None, -) -> None: - if str(src_path).startswith(str(get_project_dir(""))): - settings = FileParatextProjectSettingsParser(src_path.parent).parse() - stylesheet = settings.stylesheet - encoding = settings.encoding - book = settings.get_book_id(src_path.name) - else: - stylesheet = UsfmStylesheet("usfm.sty") - encoding = "utf-8-sig" - if book is None: - raise ValueError( - "--book argument must be passed if the source file is not in a Paratext project directory." - ) - - src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book) - draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book) - - # Verify reference parity - if len(src_refs) != len(draft_refs): - LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references") - return - for src_ref, draft_ref in zip(src_refs, draft_refs): - if src_ref.to_relaxed() != draft_ref.to_relaxed(): - LOGGER.warning( - f"Can't process {src_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure" - ) - return - - # Initialize UsfmUpdateBlockHandlers as necessary - if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): - place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents) - - with src_path.open(encoding=encoding) as f: - usfm = f.read() - rows = [([ref], sent) for ref, sent in zip(src_refs, draft_sents)] - - for postprocess_config in postprocess_configs: - update_block_handlers = [] - if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: - update_block_handlers.append(place_markers_handler) - - usfm_out = update_draft(usfm, rows, postprocess_config, update_block_handlers) - - postprocess_remark = f"Post-processing options used: {' '.join(opt for opt in postprocess_config.keys() if postprocess_config[opt])}" - usfm_out = insert_draft_remarks(usfm_out, draft_remarks + [postprocess_remark]) - - marker_placement_suffix = ( - "_" - + ("p" if postprocess_config["include_paragraph_markers"] else "") - + ("s" if postprocess_config["include_style_markers"] else "") - + ("e" if postprocess_config["include_embeds"] else "") - ) - if not out_dir: - out_dir = draft_path.parent - out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}" - with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: - f.write(usfm_out) - - -def update_draft( - usfm: str, - rows: List[Tuple[List[ScriptureRef], str]], - postprocess_config: Dict[str, bool], - update_block_handlers: List[UsfmUpdateBlockHandler] = [], -) -> str: - paragraph_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_paragraph_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - style_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_style_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - embed_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE if postprocess_config["include_embeds"] else UpdateUsfmMarkerBehavior.STRIP - ) - - handler = UpdateUsfmParserHandler( - rows=rows, - id_text=rows[0][0][0].book, - text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, - ) - parse_usfm(usfm, handler) - return handler.get_usfm() - - def main() -> None: parser = argparse.ArgumentParser(description="Applies draft postprocessing steps to a draft.") parser.add_argument( @@ -257,23 +79,28 @@ def main() -> None: experiment = args.experiment.replace("\\", "/") if args.experiment else None if experiment and get_mt_exp_dir(experiment).exists(): - exp_dir = get_mt_exp_dir(experiment) if args.clearml_queue is not None: if "cpu" not in args.clearml_queue: raise ValueError("Running this script on a GPU queue will not speed it up. Please only use CPU queues.") clearml = SILClearML(experiment, args.clearml_queue) config = clearml.config else: - with (exp_dir / "config.yml").open("r", encoding="utf-8") as file: - config = yaml.safe_load(file) - config = create_config(exp_dir, config) + config = load_config(experiment) - src_paths, draft_paths = get_paths_from_exp(config) + if not (config.exp_dir / "translate_config.yml").exists(): + raise ValueError( + "Experiment translate_config.yml not found. Please use --source and --draft options instead." + ) + src_paths, draft_paths = get_draft_paths_from_exp(config) elif args.clearml_queue is not None: raise ValueError("Must use --experiment option to use ClearML.") else: src_paths = [Path(args.source.replace("\\", "/"))] draft_paths = [Path(args.draft.replace("\\", "/"))] + if not str(src_paths[0]).startswith(str(get_project_dir(""))) and args.book is None: + raise ValueError( + "--book argument must be passed if the source file is not in a Paratext project directory." + ) # If no postprocessing options are used, use any postprocessing requests in the experiment's translate config if args.include_paragraph_markers or args.include_style_markers or args.include_embeds: @@ -289,18 +116,18 @@ def main() -> None: LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.") with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: postprocess_configs = yaml.safe_load(file).get("postprocess", []) - postprocess_configs = [merge_dict(defaultdict(lambda: False), ppc) for ppc in postprocess_configs] if len(postprocess_configs) == 0: - LOGGER.info("No postprocessing requests found.") + LOGGER.info("No postprocessing requests found in translate config.") exit() else: LOGGER.info("Please use at least one postprocessing option.") exit() + postprocess_handler = PostprocessHandler(postprocess_configs, include_base=False) if args.output_folder: args.output_folder = Path(args.output_folder.replace("\\", "/")) for src_path, draft_path in zip(src_paths, draft_paths): - postprocess_drafts(src_path, draft_path, postprocess_configs, args.book, args.output_folder) + postprocess_draft(src_path, draft_path, postprocess_handler, args.book, args.output_folder) if __name__ == "__main__": diff --git a/silnlp/common/postprocess_utils.py b/silnlp/common/postprocess_utils.py new file mode 100644 index 00000000..5fa73311 --- /dev/null +++ b/silnlp/common/postprocess_utils.py @@ -0,0 +1,111 @@ +import logging +import re +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import List, Tuple + +import yaml +from machine.corpora import ScriptureRef, UsfmFileText, UsfmStylesheet, UsfmTextType +from machine.scripture import book_number_to_id, get_chapters +from machine.translation import WordAlignmentMatrix +from transformers.trainer_utils import get_last_checkpoint + +from ..alignment.eflomal import to_word_alignment_matrix +from ..alignment.utils import compute_alignment_scores +from ..nmt.config import Config +from ..nmt.hugging_face_config import get_best_checkpoint +from .corpus import load_corpus, write_corpus +from .paratext import book_file_name_digits, get_book_path + +LOGGER = logging.getLogger(__package__ + ".postprocess_utils") + +# Marker "type" is as defined by the UsfmTokenType given to tokens by the UsfmTokenizer, +# which mostly aligns with a marker's StyleType in the USFM stylesheet +CHARACTER_TYPE_EMBEDS = ["fig", "fm", "jmp", "rq", "va", "vp", "xt", "xtSee", "xtSeeAlso"] +PARAGRAPH_TYPE_EMBEDS = ["lit", "r", "rem"] +NON_NOTE_TYPE_EMBEDS = CHARACTER_TYPE_EMBEDS + PARAGRAPH_TYPE_EMBEDS + + +def get_alignment_matrices( + src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal" +) -> List[WordAlignmentMatrix]: + with TemporaryDirectory() as td: + align_path = Path(td, "sym-align.txt") + write_corpus(Path(td, "src_align.txt"), src_sents) + write_corpus(Path(td, "trg_align.txt"), trg_sents) + compute_alignment_scores(Path(td, "src_align.txt"), Path(td, "trg_align.txt"), aligner, align_path) + + return [to_word_alignment_matrix(line) for line in load_corpus(align_path)] + + +# NOTE: to be replaced by new machine.py remark functionality +def insert_draft_remarks(usfm: str, remarks: List[str]) -> str: + lines = usfm.split("\n") + remark_lines = [f"\\rem {r}" for r in remarks] + return "\n".join(lines[:1] + remark_lines + lines[1:]) + + +# Takes the path to a USFM file and the relevant info to parse it +# and returns the text of all non-embed sentences and their respective references, +# along with any remarks (\rem) that were inserted at the beginning of the file +def get_sentences( + book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = [] +) -> Tuple[List[str], List[ScriptureRef], List[str]]: + sents = [] + refs = [] + draft_remarks = [] + for sent in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True): + marker = sent.ref.path[-1].name if len(sent.ref.path) > 0 else "" + if marker == "rem" and len(refs) == 0: # TODO: \ide and \usfm lines could potentially come before the remark(s) + draft_remarks.append(sent.text) + continue + if ( + marker in PARAGRAPH_TYPE_EMBEDS + or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT + or (len(chapters) > 0 and sent.ref.chapter_num not in chapters) + ): + continue + + sents.append(re.sub(" +", " ", sent.text.strip())) + refs.append(sent.ref) + + return sents, refs, draft_remarks + + +# Get the paths of all drafts that would be produced by an experiment's translate config and that exist +def get_draft_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path]]: + with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: + translate_requests = yaml.safe_load(file).get("translate", []) + + src_paths = [] + draft_paths = [] + for translate_request in translate_requests: + src_project = translate_request.get("src_project", next(iter(config.src_projects))) + + ckpt = translate_request.get("checkpoint", "last") + if ckpt == "best": + step_str = get_best_checkpoint(config.model_dir).name[11:] + elif ckpt == "last": + step_str = Path(get_last_checkpoint(config.model_dir)).name[11:] + else: + step_str = str(ckpt) + + book_nums = get_chapters(translate_request.get("books", [])).keys() + for book_num in book_nums: + book = book_number_to_id(book_num) + + src_path = get_book_path(src_project, book) + draft_path = ( + config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM" + ) + if draft_path.exists(): + src_paths.append(src_path) + draft_paths.append(draft_path) + elif draft_path.with_suffix(f".{1}{draft_path.suffix}").exists(): # multiple drafts + for i in range(1, config.infer.get("num_drafts", 1) + 1): + src_paths.append(src_path) + draft_paths.append(draft_path.with_suffix(f".{i}{draft_path.suffix}")) + else: + LOGGER.warning(f"Draft not found: {draft_path}") + + return src_paths, draft_paths diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index cb8d3467..f7f799f7 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -1,12 +1,11 @@ import logging import re from abc import ABC, abstractmethod -from collections import defaultdict from datetime import date from itertools import groupby from math import exp from pathlib import Path -from typing import Dict, Iterable, List, Optional +from typing import Iterable, List, Optional import docx import nltk @@ -14,7 +13,6 @@ from machine.corpora import ( FileParatextProjectSettingsParser, FileParatextProjectTextUpdater, - UpdateUsfmMarkerBehavior, UpdateUsfmParserHandler, UpdateUsfmTextBehavior, UsfmFileText, @@ -24,31 +22,24 @@ ) from machine.scripture import VerseRef +from ..nmt.postprocess import PostprocessHandler from .corpus import load_corpus, write_corpus from .paratext import get_book_path, get_iso, get_project_dir -from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler -from .utils import merge_dict +from .postprocess_utils import PARAGRAPH_TYPE_EMBEDS LOGGER = logging.getLogger(__package__ + ".translate") nltk.download("punkt") -def insert_draft_remark( - usfm: str, - book: str, - description: str, - experiment_ckpt_str: str, -) -> str: - remark = f"\\rem This draft of {book} was machine translated on {date.today()} from {description} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." - +def insert_draft_remarks(usfm: str, remarks: List[str]) -> str: lines = usfm.split("\n") insert_idx = ( 1 + (len(lines) > 1 and (lines[1].startswith("\\ide") or lines[1].startswith("\\usfm"))) + (len(lines) > 2 and (lines[2].startswith("\\ide") or lines[2].startswith("\\usfm"))) ) - lines.insert(insert_idx, remark) - return "\n".join(lines) + remarks = [f"\\rem {r}" for r in remarks] + return "\n".join(lines[:insert_idx] + remarks + lines[insert_idx:]) # A group of multiple translations of a single sentence @@ -139,7 +130,7 @@ def translate_book( produce_multiple_translations: bool = False, chapters: List[int] = [], trg_project: Optional[str] = None, - postprocess_configs: List[Dict[str, bool]] = [], + postprocess_handler: PostprocessHandler = PostprocessHandler(), experiment_ckpt_str: str = "", ) -> None: book_path = get_book_path(src_project, book) @@ -156,7 +147,7 @@ def translate_book( produce_multiple_translations, chapters, trg_project, - postprocess_configs, + postprocess_handler, experiment_ckpt_str, ) @@ -169,7 +160,7 @@ def translate_usfm( produce_multiple_translations: bool = False, chapters: List[int] = [], trg_project: Optional[str] = None, - postprocess_configs: List[Dict[str, bool]] = [], + postprocess_handler: PostprocessHandler = PostprocessHandler(), experiment_ckpt_str: str = "", ) -> None: # Create UsfmFileText object for source @@ -222,50 +213,17 @@ def translate_usfm( vrefs.insert(idx, vref) output.insert(idx, [None, None, None, None]) - # Prepare configs: add base draft and default value - postprocess_configs = [merge_dict(defaultdict(lambda: False), ppc) for ppc in [{}] + postprocess_configs] + text_behavior = ( + UpdateUsfmTextBehavior.PREFER_NEW if trg_project is not None else UpdateUsfmTextBehavior.STRIP_EXISTING + ) draft_set: DraftGroup = DraftGroup(translations) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): rows = [([ref], translation) for ref, translation in zip(vrefs, translated_draft)] - if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs): - place_markers_handler = construct_place_markers_handler(vrefs, sentences, translated_draft) - - for postprocess_config in postprocess_configs: - # Update behaviors - text_behavior = ( - UpdateUsfmTextBehavior.PREFER_NEW - if trg_project is not None - else UpdateUsfmTextBehavior.STRIP_EXISTING - ) - paragraph_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_paragraph_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - style_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_style_markers"] - else UpdateUsfmMarkerBehavior.STRIP - ) - embed_behavior = ( - UpdateUsfmMarkerBehavior.PRESERVE - if postprocess_config["include_embeds"] - else UpdateUsfmMarkerBehavior.STRIP - ) - marker_placement_suffix = ( - "_" - + ("p" if postprocess_config["include_paragraph_markers"] else "") - + ("s" if postprocess_config["include_style_markers"] else "") - + ("e" if postprocess_config["include_embeds"] else "") - ) - marker_placement_suffix = "" if len(marker_placement_suffix) == 1 else marker_placement_suffix - - update_block_handlers = [] - if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]: - update_block_handlers.append(place_markers_handler) + postprocess_handler.create_update_block_handlers(vrefs, sentences, translated_draft) + for config in postprocess_handler.configs: # Insert translation into the USFM structure of an existing project # If the target project is not the same as the translated file's original project, # no verses outside of the ones translated will be overwritten @@ -277,10 +235,10 @@ def translate_usfm( book_id=src_file_text.id, rows=rows, text_behavior=text_behavior, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, + paragraph_behavior=config.get_paragraph_behavior(), + embed_behavior=config.get_embed_behavior(), + style_behavior=config.get_style_behavior(), + update_block_handlers=config.update_block_handlers, ) if usfm_out is None: @@ -294,18 +252,25 @@ def translate_usfm( rows=rows, id_text=vrefs[0].book, text_behavior=text_behavior, - paragraph_behavior=paragraph_behavior, - embed_behavior=embed_behavior, - style_behavior=style_behavior, - update_block_handlers=update_block_handlers, + paragraph_behavior=config.get_paragraph_behavior(), + embed_behavior=config.get_embed_behavior(), + style_behavior=config.get_style_behavior(), + update_block_handlers=config.update_block_handlers, ) parse_usfm(usfm, handler) usfm_out = handler.get_usfm() - # Insert draft remark and write to output path + # Insert draft remarks description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}" - usfm_out = insert_draft_remark(usfm_out, vrefs[0].book, description, experiment_ckpt_str) - trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + marker_placement_suffix) + remarks = [ + f"This draft of {vrefs[0].book} was machine translated on {date.today()} from {description} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." + ] + if len(config.get_postprocess_remark()) > 0: + remarks.append(config.get_postprocess_remark()) + usfm_out = insert_draft_remarks(usfm_out, remarks) + + # Construct output file name write to file + trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + config.get_postprocess_suffix()) if produce_multiple_translations: trg_draft_file_path = trg_draft_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}") with trg_draft_file_path.open( diff --git a/silnlp/common/usfm_preservation.py b/silnlp/common/usfm_preservation.py deleted file mode 100644 index 70d67e04..00000000 --- a/silnlp/common/usfm_preservation.py +++ /dev/null @@ -1,48 +0,0 @@ -from abc import abstractmethod -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import List - -from machine.corpora import PlaceMarkersAlignmentInfo, PlaceMarkersUsfmUpdateBlockHandler, ScriptureRef -from machine.tokenization import LatinWordTokenizer -from machine.translation import WordAlignmentMatrix - -from ..alignment.eflomal import to_word_alignment_matrix -from ..alignment.utils import compute_alignment_scores -from .corpus import load_corpus, write_corpus - -# Marker "type" is as defined by the UsfmTokenType given to tokens by the UsfmTokenizer, -# which mostly aligns with a marker's StyleType in the USFM stylesheet -CHARACTER_TYPE_EMBEDS = ["fig", "fm", "jmp", "rq", "va", "vp", "xt", "xtSee", "xtSeeAlso"] -PARAGRAPH_TYPE_EMBEDS = ["lit", "r", "rem"] -NON_NOTE_TYPE_EMBEDS = CHARACTER_TYPE_EMBEDS + PARAGRAPH_TYPE_EMBEDS - - -def get_alignment_matrices( - src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal" -) -> List[WordAlignmentMatrix]: - with TemporaryDirectory() as td: - align_path = Path(td, "sym-align.txt") - write_corpus(Path(td, "src_align.txt"), src_sents) - write_corpus(Path(td, "trg_align.txt"), trg_sents) - compute_alignment_scores(Path(td, "src_align.txt"), Path(td, "trg_align.txt"), aligner, align_path) - - return [to_word_alignment_matrix(line) for line in load_corpus(align_path)] - - -def construct_place_markers_handler( - refs: List[ScriptureRef], source: List[str], translation: List[str], aligner: str = "eflomal" -) -> PlaceMarkersUsfmUpdateBlockHandler: - align_info = [] - tokenizer = LatinWordTokenizer() - alignments = get_alignment_matrices(source, translation, aligner) - for ref, s, t, alignment in zip(refs, source, translation, alignments): - align_info.append( - PlaceMarkersAlignmentInfo( - refs=[str(ref)], - source_tokens=list(tokenizer.tokenize(s)), - translation_tokens=list(tokenizer.tokenize(t)), - alignment=alignment, - ) - ) - return PlaceMarkersUsfmUpdateBlockHandler(align_info) diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py index 9c6d3d6b..d2f632a1 100644 --- a/silnlp/nmt/experiment.py +++ b/silnlp/nmt/experiment.py @@ -10,6 +10,7 @@ from ..common.utils import get_git_revision_hash, show_attrs from .clearml_connection import SILClearML from .config import Config, get_mt_exp_dir +from .postprocess import PostprocessHandler from .test import _SUPPORTED_SCORERS, test from .translate import TranslationTask @@ -81,6 +82,8 @@ def translate(self): with (self.config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: translate_configs = yaml.safe_load(file) + postprocess_handler = PostprocessHandler(translate_configs.get("postprocess", [])) + for config in translate_configs.get("translate", []): translator = TranslationTask( name=self.name, checkpoint=config.get("checkpoint", "last"), commit=self.commit @@ -95,7 +98,7 @@ def translate(self): config.get("trg_project"), config.get("trg_iso"), self.produce_multiple_translations, - translate_configs.get("postprocess", []), + postprocess_handler, ) elif config.get("src_prefix"): translator.translate_text_files( @@ -114,7 +117,7 @@ def translate(self): config.get("src_iso"), config.get("trg_iso"), self.produce_multiple_translations, - translate_configs.get("postprocess", []), + postprocess_handler, ) else: raise RuntimeError("A Scripture book, file, or file prefix must be specified for translation.") diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py new file mode 100644 index 00000000..af412eab --- /dev/null +++ b/silnlp/nmt/postprocess.py @@ -0,0 +1,201 @@ +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Optional, Union + +import yaml +from machine.corpora import ( + FileParatextProjectSettingsParser, + PlaceMarkersAlignmentInfo, + PlaceMarkersUsfmUpdateBlockHandler, + ScriptureRef, + UpdateUsfmMarkerBehavior, + UpdateUsfmParserHandler, + UpdateUsfmTextBehavior, + UsfmStylesheet, + UsfmUpdateBlockHandler, + parse_usfm, +) +from machine.tokenization import LatinWordTokenizer + +from ..common.paratext import get_project_dir +from ..common.postprocess_utils import ( + get_alignment_matrices, + get_draft_paths_from_exp, + get_sentences, + insert_draft_remarks, +) +from ..common.utils import get_git_revision_hash, merge_dict +from .clearml_connection import SILClearML +from .config_utils import load_config + +LOGGER = logging.getLogger(__package__ + ".postprocess") + +POSTPROCESS_OPTIONS = {"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False} +POSTPROCESS_SUFFIX_CHARS = ["p", "s", "e"] + + +class PostprocessConfig: + update_block_handlers: List[UsfmUpdateBlockHandler] = [] + + def __init__(self, config: Dict[str, Union[bool, str]]) -> None: + # TODO: need to make a copy of the default dict? + self._config = merge_dict(POSTPROCESS_OPTIONS, config) + + def _get_usfm_marker_behavior(self, preserve: bool) -> UpdateUsfmMarkerBehavior: + return UpdateUsfmMarkerBehavior.PRESERVE if preserve else UpdateUsfmMarkerBehavior.STRIP + + def get_paragraph_behavior(self) -> UpdateUsfmMarkerBehavior: + return self._get_usfm_marker_behavior(self._config["include_paragraph_markers"]) + + def get_style_behavior(self) -> UpdateUsfmMarkerBehavior: + return self._get_usfm_marker_behavior(self._config["include_style_markers"]) + + def get_embed_behavior(self) -> UpdateUsfmMarkerBehavior: + return self._get_usfm_marker_behavior(self._config["include_embeds"]) + + def get_postprocess_suffix(self) -> str: + suffix = "_" + for option, char in zip(POSTPROCESS_OPTIONS, POSTPROCESS_SUFFIX_CHARS): + if self._config[option]: + suffix += char + + return suffix if len(suffix) > 1 else "" + + def get_postprocess_remark(self) -> str: + return f"Post-processing options used: {' '.join(opt for opt in POSTPROCESS_OPTIONS if self._config[opt])}" + + +class PostprocessHandler: + # TODO: check if one of the configs is already all default? + def __init__(self, configs: List[Dict[str, Union[bool, str]]] = [], include_base: bool = True) -> None: + self.configs = [PostprocessConfig(config) for config in ([{}] if include_base else []) + configs] + + # NOTE: Update block handlers may need to be created/recreated at different times + # For example, the marker placement handler needs to be recreated for each new draft because it uses text alignment, + # but other handlers may only need to be created once overall, or once per source project. + # This may change what part of the process we want this function to be called at + def create_update_block_handlers(self, refs: List[ScriptureRef], source: List[str], translation: List[str]) -> None: + # USFM marker placement handler needs to be recreated for each draft + if any(config["include_paragraph_markers"] or config["include_style_markers"] for config in self.configs): + place_markers_handler = self._construct_place_markers_handler(refs, source, translation) + + for config in self.configs: + # TODO: make sure the configs are changing + if config["include_paragraph_markers"] or config["include_style_markers"]: + if len(config.update_block_handlers) == 0: + config.update_block_handlers.append(place_markers_handler) + else: # NOTE: this assumes a set order of update block handlers + config.update_block_handlers[0] = place_markers_handler + + def _construct_place_markers_handler( + self, refs: List[ScriptureRef], source: List[str], translation: List[str], aligner: str = "eflomal" + ) -> PlaceMarkersUsfmUpdateBlockHandler: + align_info = [] + tokenizer = LatinWordTokenizer() + alignments = get_alignment_matrices(source, translation, aligner) + for ref, s, t, alignment in zip(refs, source, translation, alignments): + align_info.append( + PlaceMarkersAlignmentInfo( + refs=[str(ref)], + source_tokens=list(tokenizer.tokenize(s)), + translation_tokens=list(tokenizer.tokenize(t)), + alignment=alignment, + ) + ) + return PlaceMarkersUsfmUpdateBlockHandler(align_info) + + +def postprocess_draft( + src_path: Path, + draft_path: Path, + postprocess_handler: PostprocessHandler, + book: Optional[str] = None, + out_dir: Optional[Path] = None, +) -> None: + if str(src_path).startswith(str(get_project_dir(""))): + settings = FileParatextProjectSettingsParser(src_path.parent).parse() + stylesheet = settings.stylesheet + encoding = settings.encoding + book = settings.get_book_id(src_path.name) + else: + stylesheet = UsfmStylesheet("usfm.sty") + encoding = "utf-8-sig" + + src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book) + draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book) + + # Verify reference parity + if len(src_refs) != len(draft_refs): + LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references") + return + for src_ref, draft_ref in zip(src_refs, draft_refs): + if src_ref.to_relaxed() != draft_ref.to_relaxed(): + LOGGER.warning( + f"Can't process {src_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure" + ) + return + + postprocess_handler.create_update_block_handlers(src_refs, src_sents, draft_sents) + + with src_path.open(encoding=encoding) as f: + usfm = f.read() + rows = [([ref], sent) for ref, sent in zip(src_refs, draft_sents)] + + for config in postprocess_handler.configs: + handler = UpdateUsfmParserHandler( + rows=rows, + id_text=book, + text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, + paragraph_behavior=config.get_paragraph_behavior(), + embed_behavior=config.get_embed_behavior(), + style_behavior=config.get_style_behavior(), + update_block_handlers=config.update_block_handlers, + ) + parse_usfm(usfm, handler) + usfm_out = handler.get_usfm() + + usfm_out = insert_draft_remarks(usfm_out, draft_remarks + [config.get_postprocess_remark()]) + + if not out_dir: + out_dir = draft_path.parent + out_path = out_dir / f"{draft_path.stem}{config.get_postprocess_suffix()}{draft_path.suffix}" + with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f: + f.write(usfm_out) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Postprocess the drafts created by an NMT model") + parser.add_argument("experiment", help="Experiment name") + parser.add_argument( + "--clearml-queue", + default=None, + type=str, + help="Run remotely on ClearML queue. Default: None - don't register with ClearML. The queue 'local' will run " + + "it locally and register it with ClearML.", + ) + args = parser.parse_args() + + get_git_revision_hash() + + if args.clearml_queue is not None: + if "cpu" not in args.clearml_queue: + raise ValueError("Running this script on a GPU queue will not speed it up. Please only use CPU queues.") + clearml = SILClearML(args.experiment, args.clearml_queue) + config = clearml.config + else: + config = load_config(args.experiment.replace("\\", "/")) + config.set_seed() + + src_paths, draft_paths = get_draft_paths_from_exp(config) + with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: + postprocess_configs = yaml.safe_load(file).get("postprocess", []) + + postprocess_handler = PostprocessHandler(postprocess_configs) + + for src_path, draft_path in zip(src_paths, draft_paths): + postprocess_draft(src_path, draft_path, postprocess_handler) + + +if __name__ == "__main__": + main() diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index f14924c7..e38bfdb4 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -4,7 +4,7 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union from machine.scripture import VerseRef, book_number_to_id, get_chapters @@ -14,6 +14,7 @@ from ..common.utils import get_git_revision_hash, show_attrs from .clearml_connection import SILClearML from .config import CheckpointType, Config, NMTModel +from .postprocess import PostprocessHandler LOGGER = logging.getLogger(__package__ + ".translate") @@ -54,7 +55,7 @@ def translate_books( trg_project: Optional[str], trg_iso: Optional[str], produce_multiple_translations: bool = False, - postprocess_configs: List[Dict[str, bool]] = [], + postprocess_handler: PostprocessHandler = PostprocessHandler(), ): book_nums = get_chapters(books) translator, config, step_str = self._init_translation_task( @@ -110,7 +111,7 @@ def translate_books( produce_multiple_translations, chapters, trg_project, - postprocess_configs, + postprocess_handler, experiment_ckpt_str, ) except Exception as e: @@ -174,7 +175,7 @@ def translate_files( src_iso: Optional[str], trg_iso: Optional[str], produce_multiple_translations: bool = False, - postprocess_configs: List[Dict[str, bool]] = [], + postprocess_handler: PostprocessHandler = PostprocessHandler(), ) -> None: translator, config, step_str = self._init_translation_task( experiment_suffix=f"_{self.checkpoint}_{os.path.basename(src)}", @@ -242,7 +243,7 @@ def translate_files( src_iso, trg_iso, produce_multiple_translations, - postprocess_configs, + postprocess_handler, experiment_ckpt_str=experiment_ckpt_str, ) @@ -382,6 +383,7 @@ def main() -> None: "include_embeds": args.include_embeds or args.include_inline_elements, } ] + postprocess_handler = PostprocessHandler(postprocess_configs) if len(args.books) > 0: if args.debug: @@ -393,7 +395,7 @@ def main() -> None: args.trg_project, args.trg_iso, args.multiple_translations, - postprocess_configs, + postprocess_handler, ) elif args.src_prefix is not None: if args.debug: @@ -419,7 +421,7 @@ def main() -> None: ) exit() translator.translate_files( - args.src, args.trg, args.src_iso, args.trg_iso, args.multiple_translations, postprocess_configs + args.src, args.trg, args.src_iso, args.trg_iso, args.multiple_translations, postprocess_handler ) else: raise RuntimeError("A Scripture book, file, or file prefix must be specified.") From d403b26c0146bbd486be30a35fc3c4e6bd113028 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Fri, 27 Jun 2025 20:54:17 +0000 Subject: [PATCH 07/10] Reorganize to resolve circular imports --- silnlp/common/compare_usfm_structure.py | 2 +- silnlp/common/postprocess_draft.py | 4 +- silnlp/common/postprocess_utils.py | 111 --------------- silnlp/common/postprocesser.py | 101 ++++++++++++++ silnlp/common/translator.py | 4 +- silnlp/common/usfm_utils.py | 11 +- silnlp/nmt/experiment.py | 2 +- silnlp/nmt/postprocess.py | 171 ++++++++++++------------ silnlp/nmt/translate.py | 2 +- 9 files changed, 200 insertions(+), 208 deletions(-) delete mode 100644 silnlp/common/postprocess_utils.py create mode 100644 silnlp/common/postprocesser.py diff --git a/silnlp/common/compare_usfm_structure.py b/silnlp/common/compare_usfm_structure.py index 3142bd43..6c6fd6ae 100644 --- a/silnlp/common/compare_usfm_structure.py +++ b/silnlp/common/compare_usfm_structure.py @@ -14,7 +14,7 @@ ) from machine.tokenization import WhitespaceTokenizer -from .usfm_preservation import CHARACTER_TYPE_EMBEDS, PARAGRAPH_TYPE_EMBEDS +from .usfm_utils import CHARACTER_TYPE_EMBEDS, PARAGRAPH_TYPE_EMBEDS LOGGER = logging.getLogger(__package__ + ".compare_usfm_structure") diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index 5f1c381f..1f36c452 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -6,9 +6,9 @@ from ..nmt.clearml_connection import SILClearML from ..nmt.config_utils import load_config -from ..nmt.postprocess import PostprocessHandler, postprocess_draft +from ..nmt.postprocess import get_draft_paths_from_exp, postprocess_draft from .paratext import get_project_dir -from .postprocess_utils import get_draft_paths_from_exp +from .postprocesser import PostprocessHandler from .utils import get_mt_exp_dir LOGGER = logging.getLogger(__package__ + ".postprocess_draft") diff --git a/silnlp/common/postprocess_utils.py b/silnlp/common/postprocess_utils.py deleted file mode 100644 index 5fa73311..00000000 --- a/silnlp/common/postprocess_utils.py +++ /dev/null @@ -1,111 +0,0 @@ -import logging -import re -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import List, Tuple - -import yaml -from machine.corpora import ScriptureRef, UsfmFileText, UsfmStylesheet, UsfmTextType -from machine.scripture import book_number_to_id, get_chapters -from machine.translation import WordAlignmentMatrix -from transformers.trainer_utils import get_last_checkpoint - -from ..alignment.eflomal import to_word_alignment_matrix -from ..alignment.utils import compute_alignment_scores -from ..nmt.config import Config -from ..nmt.hugging_face_config import get_best_checkpoint -from .corpus import load_corpus, write_corpus -from .paratext import book_file_name_digits, get_book_path - -LOGGER = logging.getLogger(__package__ + ".postprocess_utils") - -# Marker "type" is as defined by the UsfmTokenType given to tokens by the UsfmTokenizer, -# which mostly aligns with a marker's StyleType in the USFM stylesheet -CHARACTER_TYPE_EMBEDS = ["fig", "fm", "jmp", "rq", "va", "vp", "xt", "xtSee", "xtSeeAlso"] -PARAGRAPH_TYPE_EMBEDS = ["lit", "r", "rem"] -NON_NOTE_TYPE_EMBEDS = CHARACTER_TYPE_EMBEDS + PARAGRAPH_TYPE_EMBEDS - - -def get_alignment_matrices( - src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal" -) -> List[WordAlignmentMatrix]: - with TemporaryDirectory() as td: - align_path = Path(td, "sym-align.txt") - write_corpus(Path(td, "src_align.txt"), src_sents) - write_corpus(Path(td, "trg_align.txt"), trg_sents) - compute_alignment_scores(Path(td, "src_align.txt"), Path(td, "trg_align.txt"), aligner, align_path) - - return [to_word_alignment_matrix(line) for line in load_corpus(align_path)] - - -# NOTE: to be replaced by new machine.py remark functionality -def insert_draft_remarks(usfm: str, remarks: List[str]) -> str: - lines = usfm.split("\n") - remark_lines = [f"\\rem {r}" for r in remarks] - return "\n".join(lines[:1] + remark_lines + lines[1:]) - - -# Takes the path to a USFM file and the relevant info to parse it -# and returns the text of all non-embed sentences and their respective references, -# along with any remarks (\rem) that were inserted at the beginning of the file -def get_sentences( - book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = [] -) -> Tuple[List[str], List[ScriptureRef], List[str]]: - sents = [] - refs = [] - draft_remarks = [] - for sent in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True): - marker = sent.ref.path[-1].name if len(sent.ref.path) > 0 else "" - if marker == "rem" and len(refs) == 0: # TODO: \ide and \usfm lines could potentially come before the remark(s) - draft_remarks.append(sent.text) - continue - if ( - marker in PARAGRAPH_TYPE_EMBEDS - or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT - or (len(chapters) > 0 and sent.ref.chapter_num not in chapters) - ): - continue - - sents.append(re.sub(" +", " ", sent.text.strip())) - refs.append(sent.ref) - - return sents, refs, draft_remarks - - -# Get the paths of all drafts that would be produced by an experiment's translate config and that exist -def get_draft_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path]]: - with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: - translate_requests = yaml.safe_load(file).get("translate", []) - - src_paths = [] - draft_paths = [] - for translate_request in translate_requests: - src_project = translate_request.get("src_project", next(iter(config.src_projects))) - - ckpt = translate_request.get("checkpoint", "last") - if ckpt == "best": - step_str = get_best_checkpoint(config.model_dir).name[11:] - elif ckpt == "last": - step_str = Path(get_last_checkpoint(config.model_dir)).name[11:] - else: - step_str = str(ckpt) - - book_nums = get_chapters(translate_request.get("books", [])).keys() - for book_num in book_nums: - book = book_number_to_id(book_num) - - src_path = get_book_path(src_project, book) - draft_path = ( - config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM" - ) - if draft_path.exists(): - src_paths.append(src_path) - draft_paths.append(draft_path) - elif draft_path.with_suffix(f".{1}{draft_path.suffix}").exists(): # multiple drafts - for i in range(1, config.infer.get("num_drafts", 1) + 1): - src_paths.append(src_path) - draft_paths.append(draft_path.with_suffix(f".{i}{draft_path.suffix}")) - else: - LOGGER.warning(f"Draft not found: {draft_path}") - - return src_paths, draft_paths diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py new file mode 100644 index 00000000..eb60e5fa --- /dev/null +++ b/silnlp/common/postprocesser.py @@ -0,0 +1,101 @@ +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Dict, List, Union + +from machine.corpora import ( + PlaceMarkersAlignmentInfo, + PlaceMarkersUsfmUpdateBlockHandler, + ScriptureRef, + UpdateUsfmMarkerBehavior, + UsfmUpdateBlockHandler, +) +from machine.tokenization import LatinWordTokenizer +from machine.translation import WordAlignmentMatrix + +from ..alignment.eflomal import to_word_alignment_matrix +from ..alignment.utils import compute_alignment_scores +from .corpus import load_corpus, write_corpus +from .utils import merge_dict + +POSTPROCESS_OPTIONS = {"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False} +POSTPROCESS_SUFFIX_CHARS = ["p", "s", "e"] + + +class PostprocessConfig: + update_block_handlers: List[UsfmUpdateBlockHandler] = [] + + def __init__(self, config: Dict[str, Union[bool, str]]) -> None: + # TODO: need to make a copy of the default dict? + self._config = merge_dict(POSTPROCESS_OPTIONS, config) + + def _get_usfm_marker_behavior(self, preserve: bool) -> UpdateUsfmMarkerBehavior: + return UpdateUsfmMarkerBehavior.PRESERVE if preserve else UpdateUsfmMarkerBehavior.STRIP + + def get_paragraph_behavior(self) -> UpdateUsfmMarkerBehavior: + return self._get_usfm_marker_behavior(self._config["include_paragraph_markers"]) + + def get_style_behavior(self) -> UpdateUsfmMarkerBehavior: + return self._get_usfm_marker_behavior(self._config["include_style_markers"]) + + def get_embed_behavior(self) -> UpdateUsfmMarkerBehavior: + return self._get_usfm_marker_behavior(self._config["include_embeds"]) + + def get_postprocess_suffix(self) -> str: + suffix = "_" + for option, char in zip(POSTPROCESS_OPTIONS, POSTPROCESS_SUFFIX_CHARS): + if self._config[option]: + suffix += char + + return suffix if len(suffix) > 1 else "" + + def get_postprocess_remark(self) -> str: + return f"Post-processing options used: {' '.join(opt for opt in POSTPROCESS_OPTIONS if self._config[opt])}" + + +class PostprocessHandler: + def __init__(self, configs: List[Dict[str, Union[bool, str]]] = [], include_base: bool = True) -> None: + self.configs = [PostprocessConfig(config) for config in ([{}] if include_base else []) + configs] + + # NOTE: Update block handlers may need to be created/recreated at different times + # For example, the marker placement handler needs to be recreated for each new draft because it uses text alignment, + # but other handlers may only need to be created once overall, or once per source project. + # This may change what part of the process we want this function to be called at + def create_update_block_handlers(self, refs: List[ScriptureRef], source: List[str], translation: List[str]) -> None: + if any(config["include_paragraph_markers"] or config["include_style_markers"] for config in self.configs): + place_markers_handler = self._construct_place_markers_handler(refs, source, translation) + + for config in self.configs: + # TODO: make sure the configs are changing + if config["include_paragraph_markers"] or config["include_style_markers"]: + if len(config.update_block_handlers) == 0: + config.update_block_handlers.append(place_markers_handler) + else: # NOTE: this assumes a set order of update block handlers + config.update_block_handlers[0] = place_markers_handler + + def _construct_place_markers_handler( + self, refs: List[ScriptureRef], source: List[str], translation: List[str], aligner: str = "eflomal" + ) -> PlaceMarkersUsfmUpdateBlockHandler: + align_info = [] + tokenizer = LatinWordTokenizer() + alignments = self.get_alignment_matrices(source, translation, aligner) + for ref, s, t, alignment in zip(refs, source, translation, alignments): + align_info.append( + PlaceMarkersAlignmentInfo( + refs=[str(ref)], + source_tokens=list(tokenizer.tokenize(s)), + translation_tokens=list(tokenizer.tokenize(t)), + alignment=alignment, + ) + ) + return PlaceMarkersUsfmUpdateBlockHandler(align_info) + + def _get_alignment_matrices( + self, src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal" + ) -> List[WordAlignmentMatrix]: + with TemporaryDirectory() as td: + align_path = Path(td, "sym-align.txt") + write_corpus(Path(td, "src_align.txt"), src_sents) + write_corpus(Path(td, "trg_align.txt"), trg_sents) + compute_alignment_scores(Path(td, "src_align.txt"), Path(td, "trg_align.txt"), aligner, align_path) + + return [to_word_alignment_matrix(line) for line in load_corpus(align_path)] diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index f7f799f7..41357d92 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -22,10 +22,10 @@ ) from machine.scripture import VerseRef -from ..nmt.postprocess import PostprocessHandler from .corpus import load_corpus, write_corpus from .paratext import get_book_path, get_iso, get_project_dir -from .postprocess_utils import PARAGRAPH_TYPE_EMBEDS +from .postprocesser import PostprocessHandler +from .usfm_utils import PARAGRAPH_TYPE_EMBEDS LOGGER = logging.getLogger(__package__ + ".translate") nltk.download("punkt") diff --git a/silnlp/common/usfm_utils.py b/silnlp/common/usfm_utils.py index b715b6ec..7064b66c 100644 --- a/silnlp/common/usfm_utils.py +++ b/silnlp/common/usfm_utils.py @@ -2,6 +2,12 @@ from machine.corpora import FileParatextProjectSettingsParser, UsfmFileText, UsfmTokenizer, UsfmTokenType +# Marker "type" is as defined by the UsfmTokenType given to tokens by the UsfmTokenizer, +# which mostly aligns with a marker's StyleType in the USFM stylesheet +CHARACTER_TYPE_EMBEDS = ["fig", "fm", "jmp", "rq", "va", "vp", "xt", "xtSee", "xtSeeAlso"] +PARAGRAPH_TYPE_EMBEDS = ["lit", "r", "rem"] +NON_NOTE_TYPE_EMBEDS = CHARACTER_TYPE_EMBEDS + PARAGRAPH_TYPE_EMBEDS + def main() -> None: """ @@ -32,7 +38,7 @@ def main() -> None: with sentences_file.open("w", encoding=settings.encoding) as f: for sent in file_text: f.write(f"{sent}\n") - if len(sent.ref.path) > 0 and sent.ref.path[-1].name == "rem": + if len(sent.ref.path) > 0 and sent.ref.path[-1].name in PARAGRAPH_TYPE_EMBEDS: continue vrefs.append(sent.ref) @@ -40,13 +46,12 @@ def main() -> None: usfm_toks = usfm_tokenizer.tokenize(sent.text.strip()) ignore_scope = None - to_delete = ["fig"] for tok in usfm_toks: if ignore_scope is not None: if tok.type == UsfmTokenType.END and tok.marker[:-1] == ignore_scope.marker: ignore_scope = None elif tok.type == UsfmTokenType.NOTE or ( - tok.type == UsfmTokenType.CHARACTER and tok.marker in to_delete + tok.type == UsfmTokenType.CHARACTER and tok.marker in CHARACTER_TYPE_EMBEDS ): ignore_scope = tok elif tok.type in [UsfmTokenType.PARAGRAPH, UsfmTokenType.CHARACTER, UsfmTokenType.END]: diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py index d2f632a1..702e89e8 100644 --- a/silnlp/nmt/experiment.py +++ b/silnlp/nmt/experiment.py @@ -7,10 +7,10 @@ import yaml from ..common.environment import SIL_NLP_ENV +from ..common.postprocesser import PostprocessHandler from ..common.utils import get_git_revision_hash, show_attrs from .clearml_connection import SILClearML from .config import Config, get_mt_exp_dir -from .postprocess import PostprocessHandler from .test import _SUPPORTED_SCORERS, test from .translate import TranslationTask diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index af412eab..5b323c91 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -1,109 +1,106 @@ import argparse import logging +import re from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import List, Optional, Tuple import yaml from machine.corpora import ( FileParatextProjectSettingsParser, - PlaceMarkersAlignmentInfo, - PlaceMarkersUsfmUpdateBlockHandler, ScriptureRef, - UpdateUsfmMarkerBehavior, UpdateUsfmParserHandler, UpdateUsfmTextBehavior, + UsfmFileText, UsfmStylesheet, - UsfmUpdateBlockHandler, + UsfmTextType, parse_usfm, ) -from machine.tokenization import LatinWordTokenizer - -from ..common.paratext import get_project_dir -from ..common.postprocess_utils import ( - get_alignment_matrices, - get_draft_paths_from_exp, - get_sentences, - insert_draft_remarks, -) -from ..common.utils import get_git_revision_hash, merge_dict +from machine.scripture import book_number_to_id, get_chapters +from transformers.trainer_utils import get_last_checkpoint + +from ..common.paratext import book_file_name_digits, get_book_path, get_project_dir +from ..common.postprocesser import PostprocessHandler +from ..common.usfm_utils import PARAGRAPH_TYPE_EMBEDS +from ..common.utils import get_git_revision_hash from .clearml_connection import SILClearML +from .config import Config from .config_utils import load_config +from .hugging_face_config import get_best_checkpoint LOGGER = logging.getLogger(__package__ + ".postprocess") -POSTPROCESS_OPTIONS = {"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False} -POSTPROCESS_SUFFIX_CHARS = ["p", "s", "e"] - - -class PostprocessConfig: - update_block_handlers: List[UsfmUpdateBlockHandler] = [] - - def __init__(self, config: Dict[str, Union[bool, str]]) -> None: - # TODO: need to make a copy of the default dict? - self._config = merge_dict(POSTPROCESS_OPTIONS, config) - - def _get_usfm_marker_behavior(self, preserve: bool) -> UpdateUsfmMarkerBehavior: - return UpdateUsfmMarkerBehavior.PRESERVE if preserve else UpdateUsfmMarkerBehavior.STRIP - - def get_paragraph_behavior(self) -> UpdateUsfmMarkerBehavior: - return self._get_usfm_marker_behavior(self._config["include_paragraph_markers"]) - - def get_style_behavior(self) -> UpdateUsfmMarkerBehavior: - return self._get_usfm_marker_behavior(self._config["include_style_markers"]) - - def get_embed_behavior(self) -> UpdateUsfmMarkerBehavior: - return self._get_usfm_marker_behavior(self._config["include_embeds"]) - - def get_postprocess_suffix(self) -> str: - suffix = "_" - for option, char in zip(POSTPROCESS_OPTIONS, POSTPROCESS_SUFFIX_CHARS): - if self._config[option]: - suffix += char - - return suffix if len(suffix) > 1 else "" - - def get_postprocess_remark(self) -> str: - return f"Post-processing options used: {' '.join(opt for opt in POSTPROCESS_OPTIONS if self._config[opt])}" - - -class PostprocessHandler: - # TODO: check if one of the configs is already all default? - def __init__(self, configs: List[Dict[str, Union[bool, str]]] = [], include_base: bool = True) -> None: - self.configs = [PostprocessConfig(config) for config in ([{}] if include_base else []) + configs] - - # NOTE: Update block handlers may need to be created/recreated at different times - # For example, the marker placement handler needs to be recreated for each new draft because it uses text alignment, - # but other handlers may only need to be created once overall, or once per source project. - # This may change what part of the process we want this function to be called at - def create_update_block_handlers(self, refs: List[ScriptureRef], source: List[str], translation: List[str]) -> None: - # USFM marker placement handler needs to be recreated for each draft - if any(config["include_paragraph_markers"] or config["include_style_markers"] for config in self.configs): - place_markers_handler = self._construct_place_markers_handler(refs, source, translation) - - for config in self.configs: - # TODO: make sure the configs are changing - if config["include_paragraph_markers"] or config["include_style_markers"]: - if len(config.update_block_handlers) == 0: - config.update_block_handlers.append(place_markers_handler) - else: # NOTE: this assumes a set order of update block handlers - config.update_block_handlers[0] = place_markers_handler - - def _construct_place_markers_handler( - self, refs: List[ScriptureRef], source: List[str], translation: List[str], aligner: str = "eflomal" - ) -> PlaceMarkersUsfmUpdateBlockHandler: - align_info = [] - tokenizer = LatinWordTokenizer() - alignments = get_alignment_matrices(source, translation, aligner) - for ref, s, t, alignment in zip(refs, source, translation, alignments): - align_info.append( - PlaceMarkersAlignmentInfo( - refs=[str(ref)], - source_tokens=list(tokenizer.tokenize(s)), - translation_tokens=list(tokenizer.tokenize(t)), - alignment=alignment, - ) + +# NOTE: to be replaced by new machine.py remark functionality +def insert_draft_remarks(usfm: str, remarks: List[str]) -> str: + lines = usfm.split("\n") + remark_lines = [f"\\rem {r}" for r in remarks] + return "\n".join(lines[:1] + remark_lines + lines[1:]) + + +# Takes the path to a USFM file and the relevant info to parse it +# and returns the text of all non-embed sentences and their respective references, +# along with any remarks (\rem) that were inserted at the beginning of the file +def get_sentences( + book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = [] +) -> Tuple[List[str], List[ScriptureRef], List[str]]: + sents = [] + refs = [] + draft_remarks = [] + for sent in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True): + marker = sent.ref.path[-1].name if len(sent.ref.path) > 0 else "" + if marker == "rem" and len(refs) == 0: # TODO: \ide and \usfm lines could potentially come before the remark(s) + draft_remarks.append(sent.text) + continue + if ( + marker in PARAGRAPH_TYPE_EMBEDS + or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT + or (len(chapters) > 0 and sent.ref.chapter_num not in chapters) + ): + continue + + sents.append(re.sub(" +", " ", sent.text.strip())) + refs.append(sent.ref) + + return sents, refs, draft_remarks + + +# Get the paths of all drafts that would be produced by an experiment's translate config and that exist +def get_draft_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path]]: + with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: + translate_requests = yaml.safe_load(file).get("translate", []) + + src_paths = [] + draft_paths = [] + for translate_request in translate_requests: + src_project = translate_request.get("src_project", next(iter(config.src_projects))) + + ckpt = translate_request.get("checkpoint", "last") + if ckpt == "best": + step_str = get_best_checkpoint(config.model_dir).name[11:] + elif ckpt == "last": + step_str = Path(get_last_checkpoint(config.model_dir)).name[11:] + else: + step_str = str(ckpt) + + book_nums = get_chapters(translate_request.get("books", [])).keys() + for book_num in book_nums: + book = book_number_to_id(book_num) + + src_path = get_book_path(src_project, book) + draft_path = ( + config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM" ) - return PlaceMarkersUsfmUpdateBlockHandler(align_info) + if draft_path.exists(): + src_paths.append(src_path) + draft_paths.append(draft_path) + elif draft_path.with_suffix(f".{1}{draft_path.suffix}").exists(): # multiple drafts + for i in range(1, config.infer.get("num_drafts", 1) + 1): + src_paths.append(src_path) + draft_paths.append(draft_path.with_suffix(f".{i}{draft_path.suffix}")) + else: + LOGGER.warning(f"Draft not found: {draft_path}") + + return src_paths, draft_paths def postprocess_draft( diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index e38bfdb4..1c876215 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -10,11 +10,11 @@ from ..common.environment import SIL_NLP_ENV from ..common.paratext import book_file_name_digits, get_project_dir +from ..common.postprocesser import PostprocessHandler from ..common.translator import TranslationGroup, Translator from ..common.utils import get_git_revision_hash, show_attrs from .clearml_connection import SILClearML from .config import CheckpointType, Config, NMTModel -from .postprocess import PostprocessHandler LOGGER = logging.getLogger(__package__ + ".translate") From 502d97992f199c14b921aaf710d66c98bde3d7ad Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Mon, 30 Jun 2025 23:43:53 +0000 Subject: [PATCH 08/10] Fix postprocessing bugs --- silnlp/common/postprocesser.py | 19 ++++++++++--------- silnlp/common/translator.py | 5 +++-- silnlp/nmt/postprocess.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py index eb60e5fa..08bd2fbf 100644 --- a/silnlp/common/postprocesser.py +++ b/silnlp/common/postprocesser.py @@ -22,11 +22,9 @@ class PostprocessConfig: - update_block_handlers: List[UsfmUpdateBlockHandler] = [] - def __init__(self, config: Dict[str, Union[bool, str]]) -> None: - # TODO: need to make a copy of the default dict? - self._config = merge_dict(POSTPROCESS_OPTIONS, config) + self._config = merge_dict(dict(POSTPROCESS_OPTIONS), config) + self.update_block_handlers: List[UsfmUpdateBlockHandler] = [] def _get_usfm_marker_behavior(self, preserve: bool) -> UpdateUsfmMarkerBehavior: return UpdateUsfmMarkerBehavior.PRESERVE if preserve else UpdateUsfmMarkerBehavior.STRIP @@ -42,14 +40,18 @@ def get_embed_behavior(self) -> UpdateUsfmMarkerBehavior: def get_postprocess_suffix(self) -> str: suffix = "_" - for option, char in zip(POSTPROCESS_OPTIONS, POSTPROCESS_SUFFIX_CHARS): - if self._config[option]: + for (option, default), char in zip(POSTPROCESS_OPTIONS.items(), POSTPROCESS_SUFFIX_CHARS): + if self._config[option] != default: suffix += char return suffix if len(suffix) > 1 else "" def get_postprocess_remark(self) -> str: - return f"Post-processing options used: {' '.join(opt for opt in POSTPROCESS_OPTIONS if self._config[opt])}" + used = [option for (option, default) in POSTPROCESS_OPTIONS.items() if self._config[option] != default] + return f"Post-processing options used: {' '.join(used)}" if len(used) > 0 else "" + + def __getitem__(self, key): + return self._config[key] class PostprocessHandler: @@ -65,7 +67,6 @@ def create_update_block_handlers(self, refs: List[ScriptureRef], source: List[st place_markers_handler = self._construct_place_markers_handler(refs, source, translation) for config in self.configs: - # TODO: make sure the configs are changing if config["include_paragraph_markers"] or config["include_style_markers"]: if len(config.update_block_handlers) == 0: config.update_block_handlers.append(place_markers_handler) @@ -77,7 +78,7 @@ def _construct_place_markers_handler( ) -> PlaceMarkersUsfmUpdateBlockHandler: align_info = [] tokenizer = LatinWordTokenizer() - alignments = self.get_alignment_matrices(source, translation, aligner) + alignments = self._get_alignment_matrices(source, translation, aligner) for ref, s, t, alignment in zip(refs, source, translation, alignments): align_info.append( PlaceMarkersAlignmentInfo( diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 41357d92..0eb8411e 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -265,8 +265,9 @@ def translate_usfm( remarks = [ f"This draft of {vrefs[0].book} was machine translated on {date.today()} from {description} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." ] - if len(config.get_postprocess_remark()) > 0: - remarks.append(config.get_postprocess_remark()) + postprocess_remark = config.get_postprocess_remark() + if len(postprocess_remark) > 0: + remarks.append(postprocess_remark) usfm_out = insert_draft_remarks(usfm_out, remarks) # Construct output file name write to file diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index 5b323c91..d5ab20e1 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -188,7 +188,7 @@ def main() -> None: with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: postprocess_configs = yaml.safe_load(file).get("postprocess", []) - postprocess_handler = PostprocessHandler(postprocess_configs) + postprocess_handler = PostprocessHandler(postprocess_configs, include_base=False) for src_path, draft_path in zip(src_paths, draft_paths): postprocess_draft(src_path, draft_path, postprocess_handler) From ff04368795e9a280ac19467187b15026932b69a9 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Tue, 1 Jul 2025 18:22:40 +0000 Subject: [PATCH 09/10] Remove ClearML CPU queue check --- silnlp/nmt/postprocess.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index d5ab20e1..e2f49fd0 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -176,8 +176,6 @@ def main() -> None: get_git_revision_hash() if args.clearml_queue is not None: - if "cpu" not in args.clearml_queue: - raise ValueError("Running this script on a GPU queue will not speed it up. Please only use CPU queues.") clearml = SILClearML(args.experiment, args.clearml_queue) config = clearml.config else: From 95107671e2c91162f435e27606411f8602dcad08 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Thu, 3 Jul 2025 19:38:21 +0000 Subject: [PATCH 10/10] Don't make PostprocessHandler dependent on config format --- silnlp/common/postprocess_draft.py | 4 ++-- silnlp/common/postprocesser.py | 6 +++--- silnlp/nmt/experiment.py | 6 ++++-- silnlp/nmt/postprocess.py | 4 ++-- silnlp/nmt/translate.py | 16 +++++++--------- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index 1f36c452..b38faf89 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -8,7 +8,7 @@ from ..nmt.config_utils import load_config from ..nmt.postprocess import get_draft_paths_from_exp, postprocess_draft from .paratext import get_project_dir -from .postprocesser import PostprocessHandler +from .postprocesser import PostprocessConfig, PostprocessHandler from .utils import get_mt_exp_dir LOGGER = logging.getLogger(__package__ + ".postprocess_draft") @@ -122,7 +122,7 @@ def main() -> None: else: LOGGER.info("Please use at least one postprocessing option.") exit() - postprocess_handler = PostprocessHandler(postprocess_configs, include_base=False) + postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs], include_base=False) if args.output_folder: args.output_folder = Path(args.output_folder.replace("\\", "/")) diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py index 08bd2fbf..71004338 100644 --- a/silnlp/common/postprocesser.py +++ b/silnlp/common/postprocesser.py @@ -22,7 +22,7 @@ class PostprocessConfig: - def __init__(self, config: Dict[str, Union[bool, str]]) -> None: + def __init__(self, config: Dict[str, Union[bool, str]] = {}) -> None: self._config = merge_dict(dict(POSTPROCESS_OPTIONS), config) self.update_block_handlers: List[UsfmUpdateBlockHandler] = [] @@ -55,8 +55,8 @@ def __getitem__(self, key): class PostprocessHandler: - def __init__(self, configs: List[Dict[str, Union[bool, str]]] = [], include_base: bool = True) -> None: - self.configs = [PostprocessConfig(config) for config in ([{}] if include_base else []) + configs] + def __init__(self, configs: List[PostprocessConfig] = [], include_base: bool = True) -> None: + self.configs = ([PostprocessConfig()] if include_base else []) + configs # NOTE: Update block handlers may need to be created/recreated at different times # For example, the marker placement handler needs to be recreated for each new draft because it uses text alignment, diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py index 702e89e8..ea186224 100644 --- a/silnlp/nmt/experiment.py +++ b/silnlp/nmt/experiment.py @@ -7,7 +7,7 @@ import yaml from ..common.environment import SIL_NLP_ENV -from ..common.postprocesser import PostprocessHandler +from ..common.postprocesser import PostprocessConfig, PostprocessHandler from ..common.utils import get_git_revision_hash, show_attrs from .clearml_connection import SILClearML from .config import Config, get_mt_exp_dir @@ -82,7 +82,9 @@ def translate(self): with (self.config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: translate_configs = yaml.safe_load(file) - postprocess_handler = PostprocessHandler(translate_configs.get("postprocess", [])) + postprocess_handler = PostprocessHandler( + [PostprocessConfig(pc) for pc in translate_configs.get("postprocess", [])] + ) for config in translate_configs.get("translate", []): translator = TranslationTask( diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index e2f49fd0..ff308024 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -19,7 +19,7 @@ from transformers.trainer_utils import get_last_checkpoint from ..common.paratext import book_file_name_digits, get_book_path, get_project_dir -from ..common.postprocesser import PostprocessHandler +from ..common.postprocesser import PostprocessConfig, PostprocessHandler from ..common.usfm_utils import PARAGRAPH_TYPE_EMBEDS from ..common.utils import get_git_revision_hash from .clearml_connection import SILClearML @@ -186,7 +186,7 @@ def main() -> None: with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: postprocess_configs = yaml.safe_load(file).get("postprocess", []) - postprocess_handler = PostprocessHandler(postprocess_configs, include_base=False) + postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs], include_base=False) for src_path, draft_path in zip(src_paths, draft_paths): postprocess_draft(src_path, draft_path, postprocess_handler) diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index 1c876215..82f02f1a 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -10,7 +10,7 @@ from ..common.environment import SIL_NLP_ENV from ..common.paratext import book_file_name_digits, get_project_dir -from ..common.postprocesser import PostprocessHandler +from ..common.postprocesser import PostprocessConfig, PostprocessHandler from ..common.translator import TranslationGroup, Translator from ..common.utils import get_git_revision_hash, show_attrs from .clearml_connection import SILClearML @@ -376,14 +376,12 @@ def main() -> None: name=args.experiment, checkpoint=args.checkpoint, clearml_queue=args.clearml_queue, commit=args.commit ) - postprocess_configs = [ - { - "include_paragraph_markers": args.include_paragraph_markers or args.preserve_usfm_markers, - "include_style_markers": args.include_style_markers or args.preserve_usfm_markers, - "include_embeds": args.include_embeds or args.include_inline_elements, - } - ] - postprocess_handler = PostprocessHandler(postprocess_configs) + postprocess_config = { + "include_paragraph_markers": args.include_paragraph_markers or args.preserve_usfm_markers, + "include_style_markers": args.include_style_markers or args.preserve_usfm_markers, + "include_embeds": args.include_embeds or args.include_inline_elements, + } + postprocess_handler = PostprocessHandler([PostprocessConfig(postprocess_config)]) if len(args.books) > 0: if args.debug: