Skip to content

Add support for multiple postprocessing requests #759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 10 commits on Jul 7, 2025
2 changes: 1 addition & 1 deletion silnlp/common/compare_usfm_structure.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
)
from machine.tokenization import WhitespaceTokenizer

from .usfm_preservation import CHARACTER_TYPE_EMBEDS, PARAGRAPH_TYPE_EMBEDS
from .usfm_utils import CHARACTER_TYPE_EMBEDS, PARAGRAPH_TYPE_EMBEDS

LOGGER = logging.getLogger(__package__ + ".compare_usfm_structure")

180 changes: 40 additions & 140 deletions silnlp/common/postprocess_draft.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,21 @@
import argparse
import logging
import re
from pathlib import Path
from typing import List, Tuple

import yaml
from machine.corpora import (
FileParatextProjectSettingsParser,
ScriptureRef,
UpdateUsfmMarkerBehavior,
UpdateUsfmParserHandler,
UpdateUsfmTextBehavior,
UsfmFileText,
UsfmStylesheet,
UsfmTextType,
parse_usfm,
)
from machine.scripture import book_id_to_number
from transformers.trainer_utils import get_last_checkpoint

from ..nmt.clearml_connection import SILClearML
from ..nmt.config import Config
from ..nmt.config_utils import create_config
from ..nmt.hugging_face_config import get_best_checkpoint
from .paratext import book_file_name_digits, get_book_path, get_project_dir
from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler
from ..nmt.config_utils import load_config
from ..nmt.postprocess import get_draft_paths_from_exp, postprocess_draft
from .paratext import get_project_dir
from .postprocesser import PostprocessConfig, PostprocessHandler
from .utils import get_mt_exp_dir

LOGGER = logging.getLogger(__package__ + ".postprocess_draft")


# NOTE: only using first book of first translate request for now
def get_paths_from_exp(config: Config) -> Tuple[Path, Path]:
# TODO: default to first draft in the infer folder
if not (config.exp_dir / "translate_config.yml").exists():
raise ValueError("Experiment translate_config.yml not found. Please use --source and --draft options instead.")

with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
translate_config = yaml.safe_load(file)["translate"][0]
src_project = translate_config.get("src_project", next(iter(config.src_projects)))
books = translate_config["books"]
book = books[0] if isinstance(books, list) else books.split(";")[0] # TODO: handle partial book translation
book_num = book_id_to_number(book)

ckpt = translate_config.get("checkpoint", "last")
if ckpt == "best":
step_str = get_best_checkpoint(config.model_dir).name[11:]
elif ckpt == "last":
step_str = Path(get_last_checkpoint(config.model_dir)).name[11:]
else:
step_str = str(ckpt)

return (
get_book_path(src_project, book),
config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM",
)


def insert_draft_remarks(usfm: str, remarks: List[str]) -> str:
lines = usfm.split("\n")
remark_lines = [f"\\rem {r}" for r in remarks]
return "\n".join(lines[:1] + remark_lines + lines[1:])


def get_sentences(
book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = []
) -> Tuple[List[str], List[ScriptureRef], List[str]]:
sents = []
refs = []
draft_remarks = []
for sent in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True):
marker = sent.ref.path[-1].name if len(sent.ref.path) > 0 else ""
if marker == "rem" and len(refs) == 0: # TODO: \ide and \usfm lines could potentially come before the remark(s)
draft_remarks.append(sent.text)
continue
if (
marker in PARAGRAPH_TYPE_EMBEDS
or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT
# or len(sent.text.strip()) == 0
or (len(chapters) > 0 and sent.ref.chapter_num not in chapters)
):
continue

sents.append(re.sub(" +", " ", sent.text.strip()))
refs.append(sent.ref)

return sents, refs, draft_remarks


def main() -> None:
parser = argparse.ArgumentParser(
description="Applies draft postprocessing steps to a draft. Can be used with no postprocessing options to create a base draft."
)
parser = argparse.ArgumentParser(description="Applies draft postprocessing steps to a draft.")
parser.add_argument(
"--experiment",
default=None,
@@ -155,79 +79,55 @@ def main() -> None:

experiment = args.experiment.replace("\\", "/") if args.experiment else None
if experiment and get_mt_exp_dir(experiment).exists():
exp_dir = get_mt_exp_dir(experiment)
if args.clearml_queue is not None:
if "cpu" not in args.clearml_queue:
raise ValueError("Running this script on a GPU queue will not speed it up. Please only use CPU queues.")
clearml = SILClearML(experiment, args.clearml_queue)
config = clearml.config
else:
with (exp_dir / "config.yml").open("r", encoding="utf-8") as file:
config = yaml.safe_load(file)
config = create_config(exp_dir, config)
config = load_config(experiment)

src_path, draft_path = get_paths_from_exp(config)
if not (config.exp_dir / "translate_config.yml").exists():
raise ValueError(
"Experiment translate_config.yml not found. Please use --source and --draft options instead."
)
src_paths, draft_paths = get_draft_paths_from_exp(config)
elif args.clearml_queue is not None:
raise ValueError("Must use --experiment option to use ClearML.")
else:
src_path = Path(args.source.replace("\\", "/"))
draft_path = Path(args.draft.replace("\\", "/"))

if str(src_path).startswith(str(get_project_dir(""))):
settings = FileParatextProjectSettingsParser(src_path.parent).parse()
stylesheet = settings.stylesheet
encoding = settings.encoding
book = settings.get_book_id(src_path.name)
else:
stylesheet = UsfmStylesheet("usfm.sty")
encoding = "utf-8-sig"
book = args.book
if book is None:
src_paths = [Path(args.source.replace("\\", "/"))]
draft_paths = [Path(args.draft.replace("\\", "/"))]
if not str(src_paths[0]).startswith(str(get_project_dir(""))) and args.book is None:
raise ValueError(
"--book argument must be passed if the source file is not in a Paratext project directory."
)

src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book)
draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book)

if len(src_refs) != len(draft_refs):
raise ValueError("Different number of verses/references between source and draft.")
for src_ref, draft_ref in zip(src_refs, draft_refs):
if src_ref.to_relaxed() != draft_ref.to_relaxed():
raise ValueError(
f"'source' and 'draft' must have the exact same USFM structure. Mismatched ref: {src_ref} {draft_ref}"
)

paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE if args.include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_style_markers else UpdateUsfmMarkerBehavior.STRIP
embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_embeds else UpdateUsfmMarkerBehavior.STRIP

update_block_handlers = []
if args.include_paragraph_markers or args.include_style_markers:
update_block_handlers.append(construct_place_markers_handler(src_refs, src_sents, draft_sents))

with src_path.open(encoding=encoding) as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)],
id_text=book,
text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()

usfm_out = insert_draft_remarks(usfm_out, draft_remarks)

out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent
out_path = out_dir / f"{draft_path.stem}_postprocessed{draft_path.suffix}"
with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
f.write(usfm_out)
# If no postprocessing options are used, use any postprocessing requests in the experiment's translate config
if args.include_paragraph_markers or args.include_style_markers or args.include_embeds:
postprocess_configs = [
{
"include_paragraph_markers": args.include_paragraph_markers,
"include_style_markers": args.include_style_markers,
"include_embeds": args.include_embeds,
}
]
else:
if args.experiment:
LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.")
with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
postprocess_configs = yaml.safe_load(file).get("postprocess", [])
if len(postprocess_configs) == 0:
LOGGER.info("No postprocessing requests found in translate config.")
exit()
else:
LOGGER.info("Please use at least one postprocessing option.")
exit()
postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs], include_base=False)

if args.output_folder:
args.output_folder = Path(args.output_folder.replace("\\", "/"))
for src_path, draft_path in zip(src_paths, draft_paths):
postprocess_draft(src_path, draft_path, postprocess_handler, args.book, args.output_folder)


if __name__ == "__main__":
102 changes: 102 additions & 0 deletions silnlp/common/postprocesser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, List, Union

from machine.corpora import (
PlaceMarkersAlignmentInfo,
PlaceMarkersUsfmUpdateBlockHandler,
ScriptureRef,
UpdateUsfmMarkerBehavior,
UsfmUpdateBlockHandler,
)
from machine.tokenization import LatinWordTokenizer
from machine.translation import WordAlignmentMatrix

from ..alignment.eflomal import to_word_alignment_matrix
from ..alignment.utils import compute_alignment_scores
from .corpus import load_corpus, write_corpus
from .utils import merge_dict

POSTPROCESS_OPTIONS = {"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False}
POSTPROCESS_SUFFIX_CHARS = ["p", "s", "e"]


class PostprocessConfig:
    """One set of draft-postprocessing options plus the USFM update behaviors derived from them.

    Options missing from `config` fall back to the defaults in POSTPROCESS_OPTIONS.
    """

    def __init__(self, config: Union[Dict[str, Union[bool, str]], None] = None) -> None:
        # Default is None rather than a mutable {} so the default object is never shared across calls
        self._config = merge_dict(dict(POSTPROCESS_OPTIONS), config if config is not None else {})
        # Handlers are created/assigned externally (see PostprocessHandler.create_update_block_handlers)
        self.update_block_handlers: List[UsfmUpdateBlockHandler] = []

    def _get_usfm_marker_behavior(self, preserve: bool) -> UpdateUsfmMarkerBehavior:
        """Map a boolean option to the corresponding marker behavior."""
        return UpdateUsfmMarkerBehavior.PRESERVE if preserve else UpdateUsfmMarkerBehavior.STRIP

    def get_paragraph_behavior(self) -> UpdateUsfmMarkerBehavior:
        return self._get_usfm_marker_behavior(self._config["include_paragraph_markers"])

    def get_style_behavior(self) -> UpdateUsfmMarkerBehavior:
        return self._get_usfm_marker_behavior(self._config["include_style_markers"])

    def get_embed_behavior(self) -> UpdateUsfmMarkerBehavior:
        return self._get_usfm_marker_behavior(self._config["include_embeds"])

    def get_postprocess_suffix(self) -> str:
        """Return a filename suffix (e.g. "_ps") encoding the non-default options, or "" if all are default."""
        suffix = "_"
        for (option, default), char in zip(POSTPROCESS_OPTIONS.items(), POSTPROCESS_SUFFIX_CHARS):
            if self._config[option] != default:
                suffix += char

        return suffix if len(suffix) > 1 else ""

    def get_postprocess_remark(self) -> str:
        """Return a human-readable remark listing the non-default options, or "" if all are default."""
        used = [option for (option, default) in POSTPROCESS_OPTIONS.items() if self._config[option] != default]
        return f"Post-processing options used: {' '.join(used)}" if len(used) > 0 else ""

    def __getitem__(self, key: str) -> Union[bool, str]:
        return self._config[key]


class PostprocessHandler:
    """Holds a collection of PostprocessConfigs and builds the update block handlers they share."""

    def __init__(self, configs: Union[List[PostprocessConfig], None] = None, include_base: bool = True) -> None:
        # Default is None rather than a mutable [] so the default object is never shared across calls
        if configs is None:
            configs = []
        # Optionally prepend a base config (all options off) so an unprocessed draft is also produced
        self.configs = ([PostprocessConfig()] if include_base else []) + configs

    # NOTE: Update block handlers may need to be created/recreated at different times
    # For example, the marker placement handler needs to be recreated for each new draft because it uses text alignment,
    # but other handlers may only need to be created once overall, or once per source project.
    # This may change what part of the process we want this function to be called at
    def create_update_block_handlers(self, refs: List[ScriptureRef], source: List[str], translation: List[str]) -> None:
        """(Re)create the alignment-based marker placement handler for the current draft.

        The handler is shared by every config that places paragraph or style markers.
        """
        if any(config["include_paragraph_markers"] or config["include_style_markers"] for config in self.configs):
            place_markers_handler = self._construct_place_markers_handler(refs, source, translation)

            for config in self.configs:
                if config["include_paragraph_markers"] or config["include_style_markers"]:
                    if len(config.update_block_handlers) == 0:
                        config.update_block_handlers.append(place_markers_handler)
                    else:  # NOTE: this assumes a set order of update block handlers
                        config.update_block_handlers[0] = place_markers_handler

    def _construct_place_markers_handler(
        self, refs: List[ScriptureRef], source: List[str], translation: List[str], aligner: str = "eflomal"
    ) -> PlaceMarkersUsfmUpdateBlockHandler:
        """Build a marker placement handler from word alignments between source and translation sentences."""
        align_info = []
        tokenizer = LatinWordTokenizer()
        alignments = self._get_alignment_matrices(source, translation, aligner)
        for ref, s, t, alignment in zip(refs, source, translation, alignments):
            align_info.append(
                PlaceMarkersAlignmentInfo(
                    refs=[str(ref)],
                    source_tokens=list(tokenizer.tokenize(s)),
                    translation_tokens=list(tokenizer.tokenize(t)),
                    alignment=alignment,
                )
            )
        return PlaceMarkersUsfmUpdateBlockHandler(align_info)

    def _get_alignment_matrices(
        self, src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal"
    ) -> List[WordAlignmentMatrix]:
        """Align source and target sentences in a temp dir and return one alignment matrix per sentence pair."""
        with TemporaryDirectory() as td:
            align_path = Path(td, "sym-align.txt")
            write_corpus(Path(td, "src_align.txt"), src_sents)
            write_corpus(Path(td, "trg_align.txt"), trg_sents)
            compute_alignment_scores(Path(td, "src_align.txt"), Path(td, "trg_align.txt"), aligner, align_path)

            return [to_word_alignment_matrix(line) for line in load_corpus(align_path)]
132 changes: 63 additions & 69 deletions silnlp/common/translator.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,6 @@
from machine.corpora import (
FileParatextProjectSettingsParser,
FileParatextProjectTextUpdater,
UpdateUsfmMarkerBehavior,
UpdateUsfmParserHandler,
UpdateUsfmTextBehavior,
UsfmFileText,
@@ -25,28 +24,22 @@

from .corpus import load_corpus, write_corpus
from .paratext import get_book_path, get_iso, get_project_dir
from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler
from .postprocesser import PostprocessHandler
from .usfm_utils import PARAGRAPH_TYPE_EMBEDS

LOGGER = logging.getLogger(__package__ + ".translate")
nltk.download("punkt")


def insert_draft_remark(
usfm: str,
book: str,
description: str,
experiment_ckpt_str: str,
) -> str:
remark = f"\\rem This draft of {book} was machine translated on {date.today()} from {description} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."

def insert_draft_remarks(usfm: str, remarks: List[str]) -> str:
lines = usfm.split("\n")
insert_idx = (
1
+ (len(lines) > 1 and (lines[1].startswith("\\ide") or lines[1].startswith("\\usfm")))
+ (len(lines) > 2 and (lines[2].startswith("\\ide") or lines[2].startswith("\\usfm")))
)
lines.insert(insert_idx, remark)
return "\n".join(lines)
remarks = [f"\\rem {r}" for r in remarks]
return "\n".join(lines[:insert_idx] + remarks + lines[insert_idx:])


# A group of multiple translations of a single sentence
@@ -137,9 +130,7 @@ def translate_book(
produce_multiple_translations: bool = False,
chapters: List[int] = [],
trg_project: Optional[str] = None,
include_paragraph_markers: bool = False,
include_style_markers: bool = False,
include_embeds: bool = False,
postprocess_handler: PostprocessHandler = PostprocessHandler(),
experiment_ckpt_str: str = "",
) -> None:
book_path = get_book_path(src_project, book)
@@ -156,9 +147,7 @@ def translate_book(
produce_multiple_translations,
chapters,
trg_project,
include_paragraph_markers,
include_style_markers,
include_embeds,
postprocess_handler,
experiment_ckpt_str,
)

@@ -171,9 +160,7 @@ def translate_usfm(
produce_multiple_translations: bool = False,
chapters: List[int] = [],
trg_project: Optional[str] = None,
include_paragraph_markers: bool = False,
include_style_markers: bool = False,
include_embeds: bool = False,
postprocess_handler: PostprocessHandler = PostprocessHandler(),
experiment_ckpt_str: str = "",
) -> None:
# Create UsfmFileText object for source
@@ -226,72 +213,79 @@ def translate_usfm(
vrefs.insert(idx, vref)
output.insert(idx, [None, None, None, None])

# Update behaviors
text_behavior = (
UpdateUsfmTextBehavior.PREFER_NEW if trg_project is not None else UpdateUsfmTextBehavior.STRIP_EXISTING
)
paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE if include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_style_markers else UpdateUsfmMarkerBehavior.STRIP
embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_embeds else UpdateUsfmMarkerBehavior.STRIP

draft_set: DraftGroup = DraftGroup(translations)
for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1):
rows = [([ref], translation) for ref, translation in zip(vrefs, translated_draft)]

update_block_handlers = []
if include_paragraph_markers or include_style_markers:
update_block_handlers.append(construct_place_markers_handler(vrefs, sentences, translated_draft))
postprocess_handler.create_update_block_handlers(vrefs, sentences, translated_draft)

# Insert translation into the USFM structure of an existing project
# If the target project is not the same as the translated file's original project,
# no verses outside of the ones translated will be overwritten
if trg_project is not None or src_from_project:
dest_updater = FileParatextProjectTextUpdater(
get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
)
usfm_out = dest_updater.update_usfm(
book_id=src_file_text.id,
rows=rows,
text_behavior=text_behavior,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
for config in postprocess_handler.configs:
# Insert translation into the USFM structure of an existing project
# If the target project is not the same as the translated file's original project,
# no verses outside of the ones translated will be overwritten
if trg_project is not None or src_from_project:
dest_updater = FileParatextProjectTextUpdater(
get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
)
usfm_out = dest_updater.update_usfm(
book_id=src_file_text.id,
rows=rows,
text_behavior=text_behavior,
paragraph_behavior=config.get_paragraph_behavior(),
embed_behavior=config.get_embed_behavior(),
style_behavior=config.get_style_behavior(),
update_block_handlers=config.update_block_handlers,
)

if usfm_out is None:
raise FileNotFoundError(f"Book {src_file_text.id} does not exist in target project {trg_project}")
else: # Slightly more manual version for updating an individual file
with open(src_file_path, encoding="utf-8-sig") as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=rows,
id_text=vrefs[0].book,
text_behavior=text_behavior,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()
if usfm_out is None:
raise FileNotFoundError(
f"Book {src_file_text.id} does not exist in target project {trg_project}"
)
else: # Slightly more manual version for updating an individual file
with open(src_file_path, encoding="utf-8-sig") as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=rows,
id_text=vrefs[0].book,
text_behavior=text_behavior,
paragraph_behavior=config.get_paragraph_behavior(),
embed_behavior=config.get_embed_behavior(),
style_behavior=config.get_style_behavior(),
update_block_handlers=config.update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()

# Insert draft remarks
description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
remarks = [
f"This draft of {vrefs[0].book} was machine translated on {date.today()} from {description} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."
]
postprocess_remark = config.get_postprocess_remark()
if len(postprocess_remark) > 0:
remarks.append(postprocess_remark)
usfm_out = insert_draft_remarks(usfm_out, remarks)

# Construct output file name write to file
trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + config.get_postprocess_suffix())
if produce_multiple_translations:
trg_draft_file_path = trg_draft_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}")
with trg_draft_file_path.open(
"w", encoding=src_settings.encoding if src_from_project else "utf-8"
) as f:
f.write(usfm_out)

# Insert draft remark and write to output path
description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
usfm_out = insert_draft_remark(usfm_out, vrefs[0].book, description, experiment_ckpt_str)
confidence_scores_suffix = ".confidences.tsv"
if produce_multiple_translations:
trg_draft_file_path = trg_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}")
confidences_path = trg_file_path.with_suffix(
f".{draft_index}{trg_file_path.suffix}{confidence_scores_suffix}"
)
else:
trg_draft_file_path = trg_file_path
confidences_path = trg_file_path.with_suffix(f"{trg_file_path.suffix}{confidence_scores_suffix}")
with trg_draft_file_path.open("w", encoding=src_settings.encoding if src_from_project else "utf-8") as f:
f.write(usfm_out)
with confidences_path.open("w", encoding="utf-8", newline="\n") as confidences_file:
confidences_file.write("\t".join(["VRef"] + [f"Token {i}" for i in range(200)]) + "\n")
confidences_file.write("\t".join(["Sequence Score"] + [f"Token Score {i}" for i in range(200)]) + "\n")
48 changes: 0 additions & 48 deletions silnlp/common/usfm_preservation.py

This file was deleted.

11 changes: 8 additions & 3 deletions silnlp/common/usfm_utils.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,12 @@

from machine.corpora import FileParatextProjectSettingsParser, UsfmFileText, UsfmTokenizer, UsfmTokenType

# Marker "type" is as defined by the UsfmTokenType given to tokens by the UsfmTokenizer,
# which mostly aligns with a marker's StyleType in the USFM stylesheet
CHARACTER_TYPE_EMBEDS = ["fig", "fm", "jmp", "rq", "va", "vp", "xt", "xtSee", "xtSeeAlso"]
PARAGRAPH_TYPE_EMBEDS = ["lit", "r", "rem"]
NON_NOTE_TYPE_EMBEDS = CHARACTER_TYPE_EMBEDS + PARAGRAPH_TYPE_EMBEDS


def main() -> None:
"""
@@ -32,21 +38,20 @@ def main() -> None:
with sentences_file.open("w", encoding=settings.encoding) as f:
for sent in file_text:
f.write(f"{sent}\n")
if len(sent.ref.path) > 0 and sent.ref.path[-1].name == "rem":
if len(sent.ref.path) > 0 and sent.ref.path[-1].name in PARAGRAPH_TYPE_EMBEDS:
continue

vrefs.append(sent.ref)
usfm_markers.append([])
usfm_toks = usfm_tokenizer.tokenize(sent.text.strip())

ignore_scope = None
to_delete = ["fig"]
for tok in usfm_toks:
if ignore_scope is not None:
if tok.type == UsfmTokenType.END and tok.marker[:-1] == ignore_scope.marker:
ignore_scope = None
elif tok.type == UsfmTokenType.NOTE or (
tok.type == UsfmTokenType.CHARACTER and tok.marker in to_delete
tok.type == UsfmTokenType.CHARACTER and tok.marker in CHARACTER_TYPE_EMBEDS
):
ignore_scope = tok
elif tok.type in [UsfmTokenType.PARAGRAPH, UsfmTokenType.CHARACTER, UsfmTokenType.END]:
13 changes: 7 additions & 6 deletions silnlp/nmt/experiment.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import yaml

from ..common.environment import SIL_NLP_ENV
from ..common.postprocesser import PostprocessConfig, PostprocessHandler
from ..common.utils import get_git_revision_hash, show_attrs
from .clearml_connection import SILClearML
from .config import Config, get_mt_exp_dir
@@ -81,6 +82,10 @@ def translate(self):
with (self.config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
translate_configs = yaml.safe_load(file)

postprocess_handler = PostprocessHandler(
[PostprocessConfig(pc) for pc in translate_configs.get("postprocess", [])]
)

for config in translate_configs.get("translate", []):
translator = TranslationTask(
name=self.name, checkpoint=config.get("checkpoint", "last"), commit=self.commit
@@ -95,9 +100,7 @@ def translate(self):
config.get("trg_project"),
config.get("trg_iso"),
self.produce_multiple_translations,
config.get("include_paragraph_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_style_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_embeds", False) or config.get("include_inline_elements", False),
postprocess_handler,
)
elif config.get("src_prefix"):
translator.translate_text_files(
@@ -116,9 +119,7 @@ def translate(self):
config.get("src_iso"),
config.get("trg_iso"),
self.produce_multiple_translations,
config.get("include_paragraph_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_style_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_embeds", False) or config.get("include_inline_elements", False),
postprocess_handler,
)
else:
raise RuntimeError("A Scripture book, file, or file prefix must be specified for translation.")
196 changes: 196 additions & 0 deletions silnlp/nmt/postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import argparse
import logging
import re
from pathlib import Path
from typing import List, Optional, Tuple

import yaml
from machine.corpora import (
FileParatextProjectSettingsParser,
ScriptureRef,
UpdateUsfmParserHandler,
UpdateUsfmTextBehavior,
UsfmFileText,
UsfmStylesheet,
UsfmTextType,
parse_usfm,
)
from machine.scripture import book_number_to_id, get_chapters
from transformers.trainer_utils import get_last_checkpoint

from ..common.paratext import book_file_name_digits, get_book_path, get_project_dir
from ..common.postprocesser import PostprocessConfig, PostprocessHandler
from ..common.usfm_utils import PARAGRAPH_TYPE_EMBEDS
from ..common.utils import get_git_revision_hash
from .clearml_connection import SILClearML
from .config import Config
from .config_utils import load_config
from .hugging_face_config import get_best_checkpoint

LOGGER = logging.getLogger(__package__ + ".postprocess")


# NOTE: to be replaced by new machine.py remark functionality
def insert_draft_remarks(usfm: str, remarks: List[str]) -> str:
    """Insert remark lines immediately after the first line (the id line) of a USFM document.

    Empty remark strings are skipped so that no blank remark markers are written
    (e.g. when PostprocessConfig.get_postprocess_remark() returns "").
    """
    lines = usfm.split("\n")
    remark_lines = [f"\\rem {r}" for r in remarks if len(r) > 0]
    return "\n".join(lines[:1] + remark_lines + lines[1:])


# Takes the path to a USFM file and the relevant info to parse it
# and returns the text of all non-embed sentences and their respective references,
# along with any remarks (\rem) that were inserted at the beginning of the file
def get_sentences(
    book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = []
) -> Tuple[List[str], List[ScriptureRef], List[str]]:
    sentences: List[str] = []
    references: List[ScriptureRef] = []
    leading_remarks: List[str] = []
    for row in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True):
        marker = "" if len(row.ref.path) == 0 else row.ref.path[-1].name

        # Remarks that appear before any other content are collected separately
        if marker == "rem" and not references:  # TODO: \ide and \usfm lines could potentially come before the remark(s)
            leading_remarks.append(row.text)
            continue

        # Skip embed/note rows and rows outside the requested chapters
        is_embed = marker in PARAGRAPH_TYPE_EMBEDS or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT
        out_of_scope = len(chapters) > 0 and row.ref.chapter_num not in chapters
        if is_embed or out_of_scope:
            continue

        # Collapse runs of spaces so sentence text is normalized
        sentences.append(re.sub(" +", " ", row.text.strip()))
        references.append(row.ref)

    return sentences, references, leading_remarks


# Get the paths of all drafts that would be produced by an experiment's translate config and that exist
def get_draft_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path]]:
    with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
        requests = yaml.safe_load(file).get("translate", [])

    src_paths: List[Path] = []
    draft_paths: List[Path] = []
    for request in requests:
        project = request.get("src_project", next(iter(config.src_projects)))

        # Resolve the checkpoint into the step string used in the infer directory name
        checkpoint = request.get("checkpoint", "last")
        if checkpoint == "best":
            step_str = get_best_checkpoint(config.model_dir).name[11:]
        elif checkpoint == "last":
            step_str = Path(get_last_checkpoint(config.model_dir)).name[11:]
        else:
            step_str = str(checkpoint)

        for book_num in get_chapters(request.get("books", [])).keys():
            book = book_number_to_id(book_num)
            src_path = get_book_path(project, book)
            draft_path = (
                config.exp_dir / "infer" / step_str / project / f"{book_file_name_digits(book_num)}{book}.SFM"
            )

            if draft_path.exists():
                src_paths.append(src_path)
                draft_paths.append(draft_path)
            elif draft_path.with_suffix(f".1{draft_path.suffix}").exists():  # multiple drafts
                for draft_index in range(1, config.infer.get("num_drafts", 1) + 1):
                    src_paths.append(src_path)
                    draft_paths.append(draft_path.with_suffix(f".{draft_index}{draft_path.suffix}"))
            else:
                LOGGER.warning(f"Draft not found: {draft_path}")

    return src_paths, draft_paths


def postprocess_draft(
    src_path: Path,
    draft_path: Path,
    postprocess_handler: PostprocessHandler,
    book: Optional[str] = None,
    out_dir: Optional[Path] = None,
) -> None:
    """Apply each postprocessing config in `postprocess_handler` to one draft and write the results.

    src_path: USFM source file the draft was translated from. If it lives in a Paratext project
        directory, the project's settings (stylesheet/encoding/book id) are used; otherwise the
        default "usfm.sty" stylesheet and utf-8-sig are assumed and `book` must be supplied.
    draft_path: the translated USFM draft to postprocess.
    book: book id; only used when src_path is not inside a Paratext project directory.
    out_dir: output directory; defaults to the draft's own directory.
    """
    if str(src_path).startswith(str(get_project_dir(""))):
        settings = FileParatextProjectSettingsParser(src_path.parent).parse()
        stylesheet = settings.stylesheet
        encoding = settings.encoding
        book = settings.get_book_id(src_path.name)
    else:
        stylesheet = UsfmStylesheet("usfm.sty")
        encoding = "utf-8-sig"

    src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book)
    draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book)

    # Verify reference parity
    if len(src_refs) != len(draft_refs):
        LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references")
        return
    for src_ref, draft_ref in zip(src_refs, draft_refs):
        if src_ref.to_relaxed() != draft_ref.to_relaxed():
            LOGGER.warning(
                f"Can't process {src_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure"
            )
            return

    postprocess_handler.create_update_block_handlers(src_refs, src_sents, draft_sents)

    with src_path.open(encoding=encoding) as f:
        usfm = f.read()
    rows = [([ref], sent) for ref, sent in zip(src_refs, draft_sents)]

    # The output directory is the same for every config, so resolve it once before the loop
    if out_dir is None:
        out_dir = draft_path.parent

    for config in postprocess_handler.configs:
        handler = UpdateUsfmParserHandler(
            rows=rows,
            id_text=book,
            text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
            paragraph_behavior=config.get_paragraph_behavior(),
            embed_behavior=config.get_embed_behavior(),
            style_behavior=config.get_style_behavior(),
            update_block_handlers=config.update_block_handlers,
        )
        parse_usfm(usfm, handler)
        usfm_out = handler.get_usfm()

        # Only append the postprocessing remark when it is non-empty so no blank remark lines are
        # written (matches the guard used in translator.py)
        postprocess_remark = config.get_postprocess_remark()
        remarks = draft_remarks + ([postprocess_remark] if len(postprocess_remark) > 0 else [])
        usfm_out = insert_draft_remarks(usfm_out, remarks)

        out_path = out_dir / f"{draft_path.stem}{config.get_postprocess_suffix()}{draft_path.suffix}"
        with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
            f.write(usfm_out)


def main() -> None:
    """Entry point: postprocess every existing draft produced by an experiment's translate config."""
    parser = argparse.ArgumentParser(description="Postprocess the drafts created by an NMT model")
    parser.add_argument("experiment", help="Experiment name")
    parser.add_argument(
        "--clearml-queue",
        default=None,
        type=str,
        help="Run remotely on ClearML queue. Default: None - don't register with ClearML. The queue 'local' will run "
        + "it locally and register it with ClearML.",
    )
    args = parser.parse_args()

    get_git_revision_hash()

    # Load the experiment config, optionally registering the run with ClearML
    if args.clearml_queue is not None:
        clearml = SILClearML(args.experiment, args.clearml_queue)
        config = clearml.config
    else:
        config = load_config(args.experiment.replace("\\", "/"))
        config.set_seed()

    src_paths, draft_paths = get_draft_paths_from_exp(config)
    with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
        postprocess_configs = yaml.safe_load(file).get("postprocess", [])

    # Only the requests from the translate config are applied; no base (unprocessed) draft is recreated
    postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs], include_base=False)

    for src_path, draft_path in zip(src_paths, draft_paths):
        postprocess_draft(src_path, draft_path, postprocess_handler)


if __name__ == "__main__":
main()
44 changes: 14 additions & 30 deletions silnlp/nmt/translate.py
Original file line number Diff line number Diff line change
@@ -4,12 +4,13 @@
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union

from machine.scripture import VerseRef, book_number_to_id, get_chapters

from ..common.environment import SIL_NLP_ENV
from ..common.paratext import book_file_name_digits, get_project_dir
from ..common.postprocesser import PostprocessConfig, PostprocessHandler
from ..common.translator import TranslationGroup, Translator
from ..common.utils import get_git_revision_hash, show_attrs
from .clearml_connection import SILClearML
@@ -54,9 +55,7 @@ def translate_books(
trg_project: Optional[str],
trg_iso: Optional[str],
produce_multiple_translations: bool = False,
include_paragraph_markers: bool = False,
include_style_markers: bool = False,
include_embeds: bool = False,
postprocess_handler: PostprocessHandler = PostprocessHandler(),
):
book_nums = get_chapters(books)
translator, config, step_str = self._init_translation_task(
@@ -112,9 +111,7 @@ def translate_books(
produce_multiple_translations,
chapters,
trg_project,
include_paragraph_markers,
include_style_markers,
include_embeds,
postprocess_handler,
experiment_ckpt_str,
)
except Exception as e:
@@ -178,9 +175,7 @@ def translate_files(
src_iso: Optional[str],
trg_iso: Optional[str],
produce_multiple_translations: bool = False,
include_paragraph_markers: bool = False,
include_style_markers: bool = False,
include_embeds: bool = False,
postprocess_handler: PostprocessHandler = PostprocessHandler(),
) -> None:
translator, config, step_str = self._init_translation_task(
experiment_suffix=f"_{self.checkpoint}_{os.path.basename(src)}",
@@ -248,9 +243,7 @@ def translate_files(
src_iso,
trg_iso,
produce_multiple_translations,
include_paragraph_markers=include_paragraph_markers,
include_style_markers=include_style_markers,
include_embeds=include_embeds,
postprocess_handler,
experiment_ckpt_str=experiment_ckpt_str,
)

@@ -383,12 +376,12 @@ def main() -> None:
name=args.experiment, checkpoint=args.checkpoint, clearml_queue=args.clearml_queue, commit=args.commit
)

# For backwards compatibility
if args.preserve_usfm_markers:
args.include_paragraph_markers = True
args.include_style_markers = True
if args.include_inline_elements:
args.include_embeds = True
postprocess_config = {
"include_paragraph_markers": args.include_paragraph_markers or args.preserve_usfm_markers,
"include_style_markers": args.include_style_markers or args.preserve_usfm_markers,
"include_embeds": args.include_embeds or args.include_inline_elements,
}
postprocess_handler = PostprocessHandler([PostprocessConfig(postprocess_config)])

if len(args.books) > 0:
if args.debug:
@@ -400,9 +393,7 @@ def main() -> None:
args.trg_project,
args.trg_iso,
args.multiple_translations,
args.include_paragraph_markers,
args.include_style_markers,
args.include_embeds,
postprocess_handler,
)
elif args.src_prefix is not None:
if args.debug:
@@ -428,14 +419,7 @@ def main() -> None:
)
exit()
translator.translate_files(
args.src,
args.trg,
args.src_iso,
args.trg_iso,
args.multiple_translations,
args.include_paragraph_markers,
args.include_style_markers,
args.include_embeds,
args.src, args.trg, args.src_iso, args.trg_iso, args.multiple_translations, postprocess_handler
)
else:
raise RuntimeError("A Scripture book, file, or file prefix must be specified.")