Skip to content

Commit f4377b2

Browse files
committed
Fixing some issues with unassigned questions
1 parent 77390c7 commit f4377b2

File tree

1 file changed

+58
-33
lines changed

1 file changed

+58
-33
lines changed

examples/training_quora_duplicate_questions/create_splits.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
import os
4949
from sentence_transformers import util
5050

51+
random.seed(42)
52+
5153
#Get raw file
5254
source_file = "quora-IR-dataset/quora_duplicate_questions.tsv"
5355
os.makedirs('quora-IR-dataset', exist_ok=True)
@@ -104,7 +106,6 @@
104106
#Distribute rows to train/dev/test split
105107
#Ensure that sets contain distinct sentences
106108
is_assigned = set()
107-
random.seed(42)
108109
random.shuffle(rows)
109110

110111
train_ids = set()
@@ -113,15 +114,28 @@
113114

114115
counter = 0
115116
for row in rows:
116-
if row['qid1'] in is_assigned or row['qid2'] in is_assigned:
117+
if row['qid1'] in is_assigned and row['qid2'] in is_assigned:
117118
continue
118-
119-
#Distribution about 85%/5%/10%
120-
target_set = train_ids
121-
if counter%10 == 0:
122-
target_set = dev_ids
123-
elif counter%10 == 1 or counter%10 == 2:
124-
target_set = test_ids
119+
elif row['qid1'] in is_assigned or row['qid2'] in is_assigned:
120+
121+
if row['qid2'] in is_assigned: #Ensure that qid1 is assigned and qid2 not yet
122+
row['qid1'], row['qid2'] = row['qid2'], row['qid1']
123+
124+
#Move qid2 to the same split as qid1
125+
target_set = train_ids
126+
if row['qid1'] in dev_ids:
127+
target_set = dev_ids
128+
elif row['qid1'] in test_ids:
129+
target_set = test_ids
130+
131+
else:
132+
#Distribution about 85%/5%/10%
133+
target_set = train_ids
134+
if counter%10 == 0:
135+
target_set = dev_ids
136+
elif counter%10 == 1 or counter%10 == 2:
137+
target_set = test_ids
138+
counter += 1
125139

126140
#Get the sentence with all duplicates and add it to the respective sets
127141
target_set.add(row['qid1'])
@@ -134,9 +148,14 @@
134148
target_set.add(b)
135149
is_assigned.add(b)
136150

137-
counter += 1
138151

139-
print("Train sentences:", len(train_ids))
152+
#Assert all sets are mutually exclusive
153+
assert len(train_ids.intersection(dev_ids)) == 0
154+
assert len(train_ids.intersection(test_ids)) == 0
155+
assert len(test_ids.intersection(dev_ids)) == 0
156+
157+
158+
print("\nTrain sentences:", len(train_ids))
140159
print("Dev sentences:", len(dev_ids))
141160
print("Test sentences:", len(test_ids))
142161

@@ -154,8 +173,8 @@ def get_duplicate_set(ids_set):
154173
test_duplicates = get_duplicate_set(test_ids)
155174

156175

157-
print("Train duplicates", len(train_duplicates))
158-
print("dev duplicates", len(dev_duplicates))
176+
print("\nTrain duplicates", len(train_duplicates))
177+
print("Dev duplicates", len(dev_duplicates))
159178
print("Test duplicates", len(test_duplicates))
160179

161180
############### Write general files about the duplicate questions graph ############
@@ -174,7 +193,7 @@ def get_duplicate_set(ids_set):
174193
duplicates_list = sorted(duplicates_list, key=lambda x: x[0]*1000000+x[1])
175194

176195

177-
print("Write duplicate graph in pairwise format")
196+
print("\nWrite duplicate graph in pairwise format")
178197
with open('quora-IR-dataset/graph/duplicates-graph-pairwise.tsv', 'w', encoding='utf8') as fOut:
179198
fOut.write("qid1\tqid2\n")
180199
for a, b in duplicates_list:
@@ -192,7 +211,7 @@ def get_duplicate_set(ids_set):
192211
def write_qids(name, ids_list):
193212
with open('quora-IR-dataset/graph/'+name+'-questions.tsv', 'w', encoding='utf8') as fOut:
194213
fOut.write("qid\n")
195-
fOut.write("\n".join(sorted(ids_list)))
214+
fOut.write("\n".join(sorted(ids_list, key=lambda x: int(x))))
196215

197216
write_qids('train', train_ids)
198217
write_qids('dev', dev_ids)
@@ -249,54 +268,60 @@ def write_mining_files(name, ids, dups):
249268
test_queries = set()
250269

251270
#Create dev queries
252-
for a, b in dev_duplicates:
253-
if a not in corpus_ids and b not in corpus_ids:
254-
if len(dev_queries) < num_dev_queries:
271+
rnd_dev_ids = sorted(list(dev_ids))
272+
random.shuffle(rnd_dev_ids)
273+
274+
for a in rnd_dev_ids:
275+
if a not in corpus_ids:
276+
if len(dev_queries) < num_dev_queries and len(duplicates[a]) > 0:
255277
dev_queries.add(a)
256278
else:
257279
corpus_ids.add(a)
258280

259-
corpus_ids.add(b)
260-
for further_dups in duplicates[b]:
261-
if further_dups not in dev_queries:
262-
corpus_ids.add(further_dups)
281+
for b in duplicates[a]:
282+
if b not in dev_queries:
283+
corpus_ids.add(b)
263284

264285
#Create test queries
265-
for a, b in test_duplicates:
266-
if a not in corpus_ids and b not in corpus_ids:
267-
if len(test_queries) < num_test_queries:
286+
rnd_test_ids = sorted(list(test_ids))
287+
random.shuffle(rnd_test_ids)
288+
289+
for a in rnd_test_ids:
290+
if a not in corpus_ids:
291+
if len(test_queries) < num_test_queries and len(duplicates[a]) > 0:
268292
test_queries.add(a)
269293
else:
270294
corpus_ids.add(a)
271295

272-
corpus_ids.add(b)
273-
for further_dups in duplicates[b]:
274-
if further_dups not in test_queries:
275-
corpus_ids.add(further_dups)
296+
for b in duplicates[a]:
297+
if b not in test_queries:
298+
corpus_ids.add(b)
276299

277300
#Write output for information-retrieval
301+
print("\nInformation Retrival Setup")
278302
print("Corpus size:", len(corpus_ids))
279303
print("Dev queries:", len(dev_queries))
280304
print("Test queries:", len(test_queries))
281305

282306
with open('quora-IR-dataset/information-retrieval/corpus.tsv', 'w', encoding='utf8') as fOut:
283307
fOut.write("qid\tquestion\n")
284-
for id in corpus_ids:
308+
for id in sorted(corpus_ids, key=lambda id: int(id)):
285309
fOut.write("{}\t{}\n".format(id, sentences[id]))
286310

287311
with open('quora-IR-dataset/information-retrieval/dev-queries.tsv', 'w', encoding='utf8') as fOut:
288312
fOut.write("qid\tquestion\tduplicate_qids\n")
289-
for id in dev_queries:
313+
for id in sorted(dev_queries, key=lambda id: int(id)):
290314
fOut.write("{}\t{}\t{}\n".format(id, sentences[id], ",".join(duplicates[id])))
291315

292316
with open('quora-IR-dataset/information-retrieval/test-queries.tsv', 'w', encoding='utf8') as fOut:
293317
fOut.write("qid\tquestion\tduplicate_qids\n")
294-
for id in test_queries:
318+
for id in sorted(test_queries, key=lambda id: int(id)):
295319
fOut.write("{}\t{}\t{}\n".format(id, sentences[id], ",".join(duplicates[id])))
296320

297321

322+
print("--DONE--")
323+
298324

299325

300326

301327

302-
print("--DONE--")

0 commit comments

Comments
 (0)