Skip to content

Commit 4b1c071

Browse files
Merge pull request bfsujason#5 from ProMeText/add_device_conf
Add device selection
2 parents 5e90057 + 03d5a8a commit 4b1c071

File tree

5 files changed

+46
-41
lines changed

5 files changed

+46
-41
lines changed

aquilign/align/bertalign/Bertalign.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

aquilign/align/bertalign/aligner.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import numpy as np
22

3-
from aquilign.align.bertalign.Bertalign import model
43
import aquilign.align.bertalign.corelib as core
54
import aquilign.align.bertalign.utils as utils
65
import torch.nn as nn
76
import torch
87

98
class Bertalign:
109
def __init__(self,
10+
model,
1111
src,
1212
tgt,
1313
max_align=3,
@@ -17,14 +17,17 @@ def __init__(self,
1717
margin=True,
1818
len_penalty=True,
1919
is_split=False,
20-
):
20+
device="cpu"):
2121

2222
self.max_align = max_align
2323
self.top_k = top_k
2424
self.win = win
2525
self.skip = skip
2626
self.margin = margin
2727
self.len_penalty = len_penalty
28+
self.device = device
29+
self.model = model
30+
2831

2932

3033

@@ -38,11 +41,11 @@ def __init__(self,
3841
assert len(src_sents) != 0, "Problemo"
3942

4043
print("Embedding source and target text using {} ...".format(model.model_name))
41-
src_vecs, src_lens = model.transform(src_sents, max_align - 1)
42-
tgt_vecs, tgt_lens = model.transform(tgt_sents, max_align - 1)
44+
src_vecs, src_lens = self.model.transform(src_sents, max_align - 1)
45+
tgt_vecs, tgt_lens = self.model.transform(tgt_sents, max_align - 1)
4346

44-
self.search_simple_vecs = model.simple_vectorization(src_sents)
45-
self.tgt_simple_vecs = model.simple_vectorization(tgt_sents)
47+
self.search_simple_vecs = self.model.simple_vectorization(src_sents)
48+
self.tgt_simple_vecs = self.model.simple_vectorization(tgt_sents)
4649

4750
char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
4851

@@ -57,15 +60,15 @@ def __init__(self,
5760
self.tgt_vecs = tgt_vecs
5861

5962
def compute_distance(self):
60-
if torch.cuda.is_available(): # GPU version
63+
if torch.cuda.is_available() and self.device == 'cuda:0': # GPU version
6164
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
6265
output = cos(torch.from_numpy(self.search_simple_vecs), torch.from_numpy(self.tgt_simple_vecs))
6366
return output
6467

6568
def align_sents(self, first_alignment_only=False):
6669

6770
print("Performing first-step alignment ...")
68-
D, I = core.find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k)
71+
D, I = core.find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k, device=self.device)
6972
first_alignment_types = core.get_alignment_types(2) # 0-1, 1-0, 1-1
7073
first_w, first_path = core.find_first_search_path(self.src_num, self.tgt_num)
7174
first_pointers = core.first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I)

aquilign/align/bertalign/corelib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def get_alignment_types(max_alignment_size):
377377
alignment_types.append([x, y])
378378
return np.array(alignment_types)
379379

380-
def find_top_k_sents(src_vecs, tgt_vecs, k=3):
380+
def find_top_k_sents(src_vecs, tgt_vecs, k=3, device='cpu'):
381381
"""
382382
Find the top_k similar vecs in tgt_vecs for each vec in src_vecs.
383383
Args:
@@ -389,7 +389,7 @@ def find_top_k_sents(src_vecs, tgt_vecs, k=3):
389389
I: numpy array. Target index matrix of shape (num_src_sents, k).
390390
"""
391391
embedding_size = src_vecs.shape[1]
392-
if torch.cuda.is_available() and platform == 'linux': # GPU version
392+
if torch.cuda.is_available() and platform == 'linux' and device != "cpu": # GPU version
393393
res = faiss.StandardGpuResources()
394394
index = faiss.IndexFlatIP(embedding_size)
395395
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)

aquilign/align/bertalign/encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77

88
class Encoder:
9-
def __init__(self, model_name):
10-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
9+
def __init__(self, model_name, device):
10+
self.device = "cuda:0" if torch.cuda.is_available() and device != "cpu" else "cpu"
1111
if model_name == "LaBSE":
1212
self.model = SentenceTransformer(model_name_or_path=model_name, device=device)
1313
self.model_name = model_name
@@ -21,7 +21,7 @@ def simple_vectorization(self, sents):
2121
This function produces a simple vectorisation of a sentence, without
2222
taking into account its length as transform does
2323
"""
24-
sent_vecs = self.model.encode(sents)
24+
sent_vecs = self.model.encode(sents, device=self.device)
2525
return sent_vecs
2626

2727
def transform(self, sents, num_overlaps):
@@ -30,7 +30,7 @@ def transform(self, sents, num_overlaps):
3030
overlaps.append(line)
3131

3232
if self.model_name == "LaBSE":
33-
sent_vecs = self.model.encode(overlaps)
33+
sent_vecs = self.model.encode(overlaps, device=self.device)
3434
else:
3535
sents_vecs = self.t2vec_model.predict()
3636
embedding_dim = sent_vecs.size // (len(sents) * num_overlaps)

main.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import aquilign.align.graph_merge as graph_merge
1010
import aquilign.align.bertalign.utils as utils
1111
import aquilign.align.bertalign.syntactic_tokenization as syntactic_tokenization
12-
from aquilign.align.bertalign.Bertalign import Bertalign
12+
#from aquilign.align.bertalign.Bertalign import Bertalign
13+
from aquilign.align.bertalign.encoder import Encoder
14+
from aquilign.align.bertalign.aligner import Bertalign
1315
import pandas as pd
1416
import argparse
1517
import glob
@@ -60,16 +62,21 @@ class Aligner:
6062
La classe Aligner initialise le moteur d'alignement, fondé sur Bertalign
6163
"""
6264

63-
def __init__(self, corpus_size:None,
65+
def __init__(self,
66+
model,
67+
corpus_size:None,
6468
max_align=3,
6569
out_dir="out",
6670
use_punctuation=True,
6771
input_dir="in",
6872
main_wit=None,
69-
prefix=None):
73+
prefix=None,
74+
device="cpu"):
75+
self.model = model
7076
self.alignment_dict = dict()
7177
self.text_dict = dict()
7278
self.files_path = glob.glob(f"{input_dir}/*.txt")
79+
self.device = device
7380
print(input_dir)
7481
if main_wit is not None:
7582
self.main_file_index = [index for index, path in enumerate(self.files_path) if main_wit in path][0]
@@ -133,7 +140,14 @@ def parallel_align(self):
133140
else:
134141
margin = False
135142
len_penality = True
136-
aligner = Bertalign(first_tokenized_text, second_tokenized_text, max_align= self.max_align, win=5, skip=-.2, margin=margin, len_penalty=len_penality)
143+
aligner = Bertalign(self.model,
144+
first_tokenized_text,
145+
second_tokenized_text,
146+
max_align= self.max_align,
147+
win=5, skip=-.2,
148+
margin=margin,
149+
len_penalty=len_penality,
150+
device=self.device)
137151
aligner.align_sents()
138152

139153
# We append the result to the alignment dictionary
@@ -215,15 +229,25 @@ def run_alignments():
215229
help="Pivot witness.")
216230
parser.add_argument("-p", "--prefix", default=None,
217231
help="Prefix for produced files.")
232+
parser.add_argument("-d", "--device", default='cpu',
233+
help="Device to be used.")
218234

219235
args = parser.parse_args()
220236
out_dir = args.out_dir
221237
input_dir = args.input_dir
222238
main_wit = args.main_wit
223239
prefix = args.prefix
240+
device = args.device
224241
use_punctuation = args.use_punctuation
242+
243+
# Initialize model
244+
models = {0: "distiluse-base-multilingual-cased-v2", 1: "LaBSE", 2: "Sonar"}
245+
model = Encoder(models[int(1)], device=device)
246+
247+
248+
225249
print(f"Punctuation for tokenization: {use_punctuation}")
226-
MyAligner = Aligner(corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix)
250+
MyAligner = Aligner(model, corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix, device=device)
227251
MyAligner.parallel_align()
228252
utils.write_json(f"result_dir/{out_dir}/alignment_dict.json", MyAligner.alignment_dict)
229253
align_dict = utils.read_json(f"result_dir/{out_dir}/alignment_dict.json")

0 commit comments

Comments
 (0)