|
9 | 9 | import aquilign.align.graph_merge as graph_merge |
10 | 10 | import aquilign.align.bertalign.utils as utils |
11 | 11 | import aquilign.align.bertalign.syntactic_tokenization as syntactic_tokenization |
12 | | -from aquilign.align.bertalign.Bertalign import Bertalign |
| 12 | +#from aquilign.align.bertalign.Bertalign import Bertalign |
| 13 | +from aquilign.align.bertalign.encoder import Encoder |
| 14 | +from aquilign.align.bertalign.aligner import Bertalign |
13 | 15 | import pandas as pd |
14 | 16 | import argparse |
15 | 17 | import glob |
@@ -60,16 +62,21 @@ class Aligner: |
60 | 62 | La classe Aligner initialise le moteur d'alignement, fondé sur Bertalign |
61 | 63 | """ |
62 | 64 |
|
63 | | - def __init__(self, corpus_size:None, |
| 65 | + def __init__(self, |
| 66 | + model, |
| 67 | + corpus_size:None, |
64 | 68 | max_align=3, |
65 | 69 | out_dir="out", |
66 | 70 | use_punctuation=True, |
67 | 71 | input_dir="in", |
68 | 72 | main_wit=None, |
69 | | - prefix=None): |
| 73 | + prefix=None, |
| 74 | + device="cpu"): |
| 75 | + self.model = model |
70 | 76 | self.alignment_dict = dict() |
71 | 77 | self.text_dict = dict() |
72 | 78 | self.files_path = glob.glob(f"{input_dir}/*.txt") |
| 79 | + self.device = device |
73 | 80 | print(input_dir) |
74 | 81 | if main_wit is not None: |
75 | 82 | self.main_file_index = [index for index, path in enumerate(self.files_path) if main_wit in path][0] |
@@ -133,7 +140,14 @@ def parallel_align(self): |
133 | 140 | else: |
134 | 141 | margin = False |
135 | 142 | len_penality = True |
136 | | - aligner = Bertalign(first_tokenized_text, second_tokenized_text, max_align= self.max_align, win=5, skip=-.2, margin=margin, len_penalty=len_penality) |
| 143 | + aligner = Bertalign(self.model, |
| 144 | + first_tokenized_text, |
| 145 | + second_tokenized_text, |
| 146 | + max_align= self.max_align, |
| 147 | + win=5, skip=-.2, |
| 148 | + margin=margin, |
| 149 | + len_penalty=len_penality, |
| 150 | + device=self.device) |
137 | 151 | aligner.align_sents() |
138 | 152 |
|
139 | 153 | # We append the result to the alignment dictionary |
@@ -215,15 +229,25 @@ def run_alignments(): |
215 | 229 | help="Pivot witness.") |
216 | 230 | parser.add_argument("-p", "--prefix", default=None, |
217 | 231 | help="Prefix for produced files.") |
| 232 | + parser.add_argument("-d", "--device", default='cpu', |
| 233 | + help="Device to be used.") |
218 | 234 |
|
219 | 235 | args = parser.parse_args() |
220 | 236 | out_dir = args.out_dir |
221 | 237 | input_dir = args.input_dir |
222 | 238 | main_wit = args.main_wit |
223 | 239 | prefix = args.prefix |
| 240 | + device = args.device |
224 | 241 | use_punctuation = args.use_punctuation |
| 242 | + |
| 243 | + # Initialize model |
| 244 | + models = {0: "distiluse-base-multilingual-cased-v2", 1: "LaBSE", 2: "Sonar"} |
| 245 | +    model = Encoder(models[1], device=device)  # 1 = LaBSE; TODO: expose model choice as a CLI argument (e.g. --model with choices 0-2) |
| 246 | + |
| 247 | + |
| 248 | + |
225 | 249 | print(f"Punctuation for tokenization: {use_punctuation}") |
226 | | - MyAligner = Aligner(corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix) |
| 250 | + MyAligner = Aligner(model, corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix, device=device) |
227 | 251 | MyAligner.parallel_align() |
228 | 252 | utils.write_json(f"result_dir/{out_dir}/alignment_dict.json", MyAligner.alignment_dict) |
229 | 253 | align_dict = utils.read_json(f"result_dir/{out_dir}/alignment_dict.json") |
|
0 commit comments