4848import os
4949from sentence_transformers import util
5050
51+ random .seed (42 )
52+
5153#Get raw file
5254source_file = "quora-IR-dataset/quora_duplicate_questions.tsv"
5355os .makedirs ('quora-IR-dataset' , exist_ok = True )
104106#Distribute rows to train/dev/test split
105107#Ensure that sets contain distinct sentences
106108is_assigned = set ()
107- random .seed (42 )
108109random .shuffle (rows )
109110
110111train_ids = set ()
113114
114115counter = 0
115116for row in rows :
116- if row ['qid1' ] in is_assigned or row ['qid2' ] in is_assigned :
117+ if row ['qid1' ] in is_assigned and row ['qid2' ] in is_assigned :
117118 continue
118-
119- #Distribution about 85%/5%/10%
120- target_set = train_ids
121- if counter % 10 == 0 :
122- target_set = dev_ids
123- elif counter % 10 == 1 or counter % 10 == 2 :
124- target_set = test_ids
119+ elif row ['qid1' ] in is_assigned or row ['qid2' ] in is_assigned :
120+
121+ if row ['qid2' ] in is_assigned : #Swap if needed so that qid1 is the already-assigned question and qid2 the unassigned one
122+ row ['qid1' ], row ['qid2' ] = row ['qid2' ], row ['qid1' ]
123+
124+ #Move qid2 to the same split as qid1
125+ target_set = train_ids
126+ if row ['qid1' ] in dev_ids :
127+ target_set = dev_ids
128+ elif row ['qid1' ] in test_ids :
129+ target_set = test_ids
130+
131+ else :
132+ #Distribution about 85%/5%/10%
133+ target_set = train_ids
134+ if counter % 10 == 0 :
135+ target_set = dev_ids
136+ elif counter % 10 == 1 or counter % 10 == 2 :
137+ target_set = test_ids
138+ counter += 1
125139
126140 #Get the sentence with all duplicates and add it to the respective sets
127141 target_set .add (row ['qid1' ])
134148 target_set .add (b )
135149 is_assigned .add (b )
136150
137- counter += 1
138151
139- print ("Train sentences:" , len (train_ids ))
152+ #Assert all sets are mutually exclusive
153+ assert len (train_ids .intersection (dev_ids )) == 0
154+ assert len (train_ids .intersection (test_ids )) == 0
155+ assert len (test_ids .intersection (dev_ids )) == 0
156+
157+
158+ print ("\n Train sentences:" , len (train_ids ))
140159print ("Dev sentences:" , len (dev_ids ))
141160print ("Test sentences:" , len (test_ids ))
142161
@@ -154,8 +173,8 @@ def get_duplicate_set(ids_set):
154173test_duplicates = get_duplicate_set (test_ids )
155174
156175
157- print ("Train duplicates" , len (train_duplicates ))
158- print ("dev duplicates" , len (dev_duplicates ))
176+ print ("\n Train duplicates" , len (train_duplicates ))
177+ print ("Dev duplicates" , len (dev_duplicates ))
159178print ("Test duplicates" , len (test_duplicates ))
160179
161180############### Write general files about the duplicate questions graph ############
@@ -174,7 +193,7 @@ def get_duplicate_set(ids_set):
174193duplicates_list = sorted (duplicates_list , key = lambda x : x [0 ]* 1000000 + x [1 ])
175194
176195
177- print ("Write duplicate graph in pairwise format" )
196+ print ("\n Write duplicate graph in pairwise format" )
178197with open ('quora-IR-dataset/graph/duplicates-graph-pairwise.tsv' , 'w' , encoding = 'utf8' ) as fOut :
179198 fOut .write ("qid1\t qid2\n " )
180199 for a , b in duplicates_list :
@@ -192,7 +211,7 @@ def get_duplicate_set(ids_set):
192211def write_qids (name , ids_list ):
193212 with open ('quora-IR-dataset/graph/' + name + '-questions.tsv' , 'w' , encoding = 'utf8' ) as fOut :
194213 fOut .write ("qid\n " )
195- fOut .write ("\n " .join (sorted (ids_list )))
214+ fOut .write ("\n " .join (sorted (ids_list , key = lambda x : int ( x ) )))
196215
197216write_qids ('train' , train_ids )
198217write_qids ('dev' , dev_ids )
@@ -249,54 +268,60 @@ def write_mining_files(name, ids, dups):
249268test_queries = set ()
250269
251270#Create dev queries
252- for a , b in dev_duplicates :
253- if a not in corpus_ids and b not in corpus_ids :
254- if len (dev_queries ) < num_dev_queries :
271+ rnd_dev_ids = sorted (list (dev_ids ))
272+ random .shuffle (rnd_dev_ids )
273+
274+ for a in rnd_dev_ids :
275+ if a not in corpus_ids :
276+ if len (dev_queries ) < num_dev_queries and len (duplicates [a ]) > 0 :
255277 dev_queries .add (a )
256278 else :
257279 corpus_ids .add (a )
258280
259- corpus_ids .add (b )
260- for further_dups in duplicates [b ]:
261- if further_dups not in dev_queries :
262- corpus_ids .add (further_dups )
281+ for b in duplicates [a ]:
282+ if b not in dev_queries :
283+ corpus_ids .add (b )
263284
264285#Create test queries
265- for a , b in test_duplicates :
266- if a not in corpus_ids and b not in corpus_ids :
267- if len (test_queries ) < num_test_queries :
286+ rnd_test_ids = sorted (list (test_ids ))
287+ random .shuffle (rnd_test_ids )
288+
289+ for a in rnd_test_ids :
290+ if a not in corpus_ids :
291+ if len (test_queries ) < num_test_queries and len (duplicates [a ]) > 0 :
268292 test_queries .add (a )
269293 else :
270294 corpus_ids .add (a )
271295
272- corpus_ids .add (b )
273- for further_dups in duplicates [b ]:
274- if further_dups not in test_queries :
275- corpus_ids .add (further_dups )
296+ for b in duplicates [a ]:
297+ if b not in test_queries :
298+ corpus_ids .add (b )
276299
277300#Write output for information-retrieval
301+ print ("\n Information Retrival Setup" )
278302print ("Corpus size:" , len (corpus_ids ))
279303print ("Dev queries:" , len (dev_queries ))
280304print ("Test queries:" , len (test_queries ))
281305
282306with open ('quora-IR-dataset/information-retrieval/corpus.tsv' , 'w' , encoding = 'utf8' ) as fOut :
283307 fOut .write ("qid\t question\n " )
284- for id in corpus_ids :
308+ for id in sorted ( corpus_ids , key = lambda id : int ( id )) :
285309 fOut .write ("{}\t {}\n " .format (id , sentences [id ]))
286310
287311with open ('quora-IR-dataset/information-retrieval/dev-queries.tsv' , 'w' , encoding = 'utf8' ) as fOut :
288312 fOut .write ("qid\t question\t duplicate_qids\n " )
289- for id in dev_queries :
313+ for id in sorted ( dev_queries , key = lambda id : int ( id )) :
290314 fOut .write ("{}\t {}\t {}\n " .format (id , sentences [id ], "," .join (duplicates [id ])))
291315
292316with open ('quora-IR-dataset/information-retrieval/test-queries.tsv' , 'w' , encoding = 'utf8' ) as fOut :
293317 fOut .write ("qid\t question\t duplicate_qids\n " )
294- for id in test_queries :
318+ for id in sorted ( test_queries , key = lambda id : int ( id )) :
295319 fOut .write ("{}\t {}\t {}\n " .format (id , sentences [id ], "," .join (duplicates [id ])))
296320
297321
322+ print ("--DONE--" )
323+
298324
299325
300326
301327
302- print ("--DONE--" )
0 commit comments