205 changes: 127 additions & 78 deletions silnlp/common/postprocess_draft.py
@@ -1,6 +1,7 @@
import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

@@ -16,7 +17,7 @@
UsfmTextType,
parse_usfm,
)
from machine.scripture import book_id_to_number
from machine.scripture import book_number_to_id, get_chapters
from transformers.trainer_utils import get_last_checkpoint

from ..nmt.clearml_connection import SILClearML
@@ -25,36 +26,40 @@
from ..nmt.hugging_face_config import get_best_checkpoint
from .paratext import book_file_name_digits, get_book_path, get_project_dir
from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler
from .utils import get_mt_exp_dir
from .utils import get_mt_exp_dir, merge_dict

LOGGER = logging.getLogger(__package__ + ".postprocess_draft")


# NOTE: only using first book of first translate request for now
def get_paths_from_exp(config: Config) -> Tuple[Path, Path]:
# TODO: default to first draft in the infer folder
def get_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path]]:
if not (config.exp_dir / "translate_config.yml").exists():
raise ValueError("Experiment translate_config.yml not found. Please use --source and --draft options instead.")

with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
translate_config = yaml.safe_load(file)["translate"][0]
src_project = translate_config.get("src_project", next(iter(config.src_projects)))
books = translate_config["books"]
book = books[0] if isinstance(books, list) else books.split(";")[0] # TODO: handle partial book translation
book_num = book_id_to_number(book)

ckpt = translate_config.get("checkpoint", "last")
if ckpt == "best":
step_str = get_best_checkpoint(config.model_dir).name[11:]
elif ckpt == "last":
step_str = Path(get_last_checkpoint(config.model_dir)).name[11:]
else:
step_str = str(ckpt)
translate_requests = yaml.safe_load(file).get("translate", [])

return (
get_book_path(src_project, book),
config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM",
)
src_paths = []
draft_paths = []
for translate_request in translate_requests:
src_project = translate_request.get("src_project", next(iter(config.src_projects)))

ckpt = translate_request.get("checkpoint", "last")
if ckpt == "best":
step_str = get_best_checkpoint(config.model_dir).name[11:]
elif ckpt == "last":
step_str = Path(get_last_checkpoint(config.model_dir)).name[11:]
else:
step_str = str(ckpt)

book_nums = get_chapters(translate_request.get("books", [])).keys()
for book_num in book_nums:
book = book_number_to_id(book_num)

src_paths.append(get_book_path(src_project, book))
draft_paths.append(
config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM"
)

return src_paths, draft_paths
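
For context, a minimal, self-contained sketch of the path composition this function performs. The experiment directory, project name, book, and checkpoint step below are hypothetical stand-ins, not values from this PR; in the real code they come from translate_config.yml, get_best_checkpoint/get_last_checkpoint, get_chapters, and book_file_name_digits.

from pathlib import Path

# Hypothetical values; the real code reads them from translate_config.yml and
# the checkpoint directory.
exp_dir = Path("MT/experiments/my_exp")    # assumed experiment directory
src_project = "SRC_PROJ"                   # assumed Paratext project name
book, book_digits = "GEN", "01"            # assumed result of book_file_name_digits(1)

# name[11:] strips the "checkpoint-" prefix (11 characters), leaving the step.
step_str = "checkpoint-5000"[11:]          # -> "5000"

draft_path = exp_dir / "infer" / step_str / src_project / f"{book_digits}{book}.SFM"
print(draft_path)                          # MT/experiments/my_exp/infer/5000/SRC_PROJ/01GEN.SFM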


def insert_draft_remarks(usfm: str, remarks: List[str]) -> str:
@@ -89,9 +94,7 @@ def get_sentences(


def main() -> None:
parser = argparse.ArgumentParser(
description="Applies draft postprocessing steps to a draft. Can be used with no postprocessing options to create a base draft."
)
parser = argparse.ArgumentParser(description="Applies draft postprocessing steps to a draft.")
parser.add_argument(
"--experiment",
default=None,
@@ -166,68 +169,114 @@ def main() -> None:
config = yaml.safe_load(file)
config = create_config(exp_dir, config)

src_path, draft_path = get_paths_from_exp(config)
src_paths, draft_paths = get_paths_from_exp(config)
elif args.clearml_queue is not None:
raise ValueError("Must use --experiment option to use ClearML.")
else:
src_path = Path(args.source.replace("\\", "/"))
draft_path = Path(args.draft.replace("\\", "/"))

if str(src_path).startswith(str(get_project_dir(""))):
settings = FileParatextProjectSettingsParser(src_path.parent).parse()
stylesheet = settings.stylesheet
encoding = settings.encoding
book = settings.get_book_id(src_path.name)
src_paths = [Path(args.source.replace("\\", "/"))]
draft_paths = [Path(args.draft.replace("\\", "/"))]

# If no postprocessing options are used, use any postprocessing requests in the experiment's translate config
if args.include_paragraph_markers or args.include_style_markers or args.include_embeds:
postprocess_configs = [
{
"include_paragraph_markers": args.include_paragraph_markers,
"include_style_markers": args.include_style_markers,
"include_embeds": args.include_embeds,
}
]
else:
stylesheet = UsfmStylesheet("usfm.sty")
encoding = "utf-8-sig"
book = args.book
if book is None:
raise ValueError(
"--book argument must be passed if the source file is not in a Paratext project directory."
)
if args.experiment:
LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.")
with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
postprocess_configs = yaml.safe_load(file).get("postprocess", [])
postprocess_configs = [merge_dict(defaultdict(lambda: False), ppc) for ppc in postprocess_configs]
if len(postprocess_configs) == 0:
LOGGER.info("No postprocessing requests found.")
exit()
else:
LOGGER.info("Please use at least one postprocessing option.")
exit()
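
A small sketch, assuming the shape of the postprocess entries, of how a request read from translate_config.yml ends up with every unspecified option defaulting to False. merge_dict is silnlp's own helper; the plain dict update below only approximates that behavior for illustration.

from collections import defaultdict

request = {"include_paragraph_markers": True}  # hypothetical entry from translate_config.yml
ppc = defaultdict(lambda: False)
ppc.update(request)                            # stands in for merge_dict(defaultdict(lambda: False), request)

print(ppc["include_paragraph_markers"])        # True: explicitly requested
print(ppc["include_embeds"])                   # False: missing keys fall back to the default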

src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book)
draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book)
for src_path, draft_path in zip(src_paths, draft_paths):
if str(src_path).startswith(str(get_project_dir(""))):
settings = FileParatextProjectSettingsParser(src_path.parent).parse()
stylesheet = settings.stylesheet
encoding = settings.encoding
book = settings.get_book_id(src_path.name)
else:
stylesheet = UsfmStylesheet("usfm.sty")
encoding = "utf-8-sig"
book = args.book
if book is None:
raise ValueError(
"--book argument must be passed if the source file is not in a Paratext project directory."
)

src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book)
draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book)

if len(src_refs) != len(draft_refs):
LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references")
continue
for src_ref, draft_ref in zip(src_refs, draft_refs):
if src_ref.to_relaxed() != draft_ref.to_relaxed():
LOGGER.warning(
f"Can't process {src_path} and {draft_path}: Mismatched ref: {src_ref} {draft_ref}. Files must have the exact same USFM structure"
)
continue

if any(ppc["include_paragraph_markers"] or ppc["include_style_markers"] for ppc in postprocess_configs):
place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents)

for postprocess_config in postprocess_configs:
update_block_handlers = []
if postprocess_config["include_paragraph_markers"] or postprocess_config["include_style_markers"]:
update_block_handlers.append(place_markers_handler)

if len(src_refs) != len(draft_refs):
raise ValueError("Different number of verses/references between source and draft.")
for src_ref, draft_ref in zip(src_refs, draft_refs):
if src_ref.to_relaxed() != draft_ref.to_relaxed():
raise ValueError(
f"'source' and 'draft' must have the exact same USFM structure. Mismatched ref: {src_ref} {draft_ref}"
paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config["include_paragraph_markers"]
else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config["include_style_markers"]
else UpdateUsfmMarkerBehavior.STRIP
)
embed_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config["include_embeds"]
else UpdateUsfmMarkerBehavior.STRIP
)
marker_placement_suffix = (
"_"
+ ("p" if postprocess_config["include_paragraph_markers"] else "")
+ ("s" if postprocess_config["include_style_markers"] else "")
+ ("e" if postprocess_config["include_embeds"] else "")
)

paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE if args.include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_style_markers else UpdateUsfmMarkerBehavior.STRIP
embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_embeds else UpdateUsfmMarkerBehavior.STRIP

update_block_handlers = []
if args.include_paragraph_markers or args.include_style_markers:
update_block_handlers.append(construct_place_markers_handler(src_refs, src_sents, draft_sents))

with src_path.open(encoding=encoding) as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)],
id_text=book,
text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()
with src_path.open(encoding=encoding) as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)],
id_text=book,
text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()

usfm_out = insert_draft_remarks(usfm_out, draft_remarks)
# TODO: back a level
usfm_out = insert_draft_remarks(usfm_out, draft_remarks)

out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent
out_path = out_dir / f"{draft_path.stem}_postprocessed{draft_path.suffix}"
with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
f.write(usfm_out)
out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent
out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}"
with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
f.write(usfm_out)
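
A short sketch of how the marker-placement suffix and the output file name fit together, using made-up option values and a made-up draft path:

from pathlib import Path

# Hypothetical option values for one postprocess request.
include_paragraph_markers, include_style_markers, include_embeds = True, False, True

marker_placement_suffix = (
    "_"
    + ("p" if include_paragraph_markers else "")
    + ("s" if include_style_markers else "")
    + ("e" if include_embeds else "")
)
print(marker_placement_suffix)  # "_pe"

draft_path = Path("infer/5000/SRC_PROJ/01GEN.SFM")  # assumed draft location
out_path = draft_path.parent / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}"
print(out_path)                 # infer/5000/SRC_PROJ/01GEN_pe.SFM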


if __name__ == "__main__":