diff --git a/bert.py b/bert.py index 64bcc2b..8342606 100644 --- a/bert.py +++ b/bert.py @@ -548,7 +548,7 @@ def main(_): if task_name not in processors: raise ValueError("Task not found: %s" % (task_name)) - processor = processors[task_name]() + processor = processors[task_name](bert=True) label_list = processor.get_labels() diff --git a/data_processors.py b/data_processors.py index 77a3119..dc69519 100644 --- a/data_processors.py +++ b/data_processors.py @@ -94,6 +94,9 @@ def _read_tsv(cls, input_file, quotechar=None): class QcFineProcessor(DataProcessor): """Processor for the MultiNLI data set (GLUE version).""" + def __init__(self, bert=False): + self.bert = bert + def get_labeled_examples(self, data_dir): """See base class.""" return self._create_examples(os.path.join(data_dir, "labeled.tsv"), "train") @@ -108,7 +111,14 @@ def get_test_examples(self, data_dir): def get_labels(self): """See base class.""" - return ["UNK_UNK", "ABBR_abb", "ABBR_exp", "DESC_def", "DESC_desc", "DESC_manner", "DESC_reason", "ENTY_animal", "ENTY_body", "ENTY_color", "ENTY_cremat", "ENTY_currency", "ENTY_dismed", "ENTY_event", "ENTY_food", "ENTY_instru", "ENTY_lang", "ENTY_letter", "ENTY_other", "ENTY_plant", "ENTY_product", "ENTY_religion", "ENTY_sport", "ENTY_substance", "ENTY_symbol", "ENTY_techmeth", "ENTY_termeq", "ENTY_veh", "ENTY_word", "HUM_desc", "HUM_gr", "HUM_ind", "HUM_title", "LOC_city", "LOC_country", "LOC_mount", "LOC_other", "LOC_state", "NUM_code", "NUM_count", "NUM_date", "NUM_dist", "NUM_money", "NUM_ord", "NUM_other", "NUM_perc", "NUM_period", "NUM_speed", "NUM_temp", "NUM_volsize", "NUM_weight"] + label_list = ["UNK_UNK", "ABBR_abb", "ABBR_exp", "DESC_def", "DESC_desc", "DESC_manner", "DESC_reason", "ENTY_animal", "ENTY_body", "ENTY_color", "ENTY_cremat", "ENTY_currency", "ENTY_dismed", "ENTY_event", "ENTY_food", "ENTY_instru", "ENTY_lang", "ENTY_letter", "ENTY_other", "ENTY_plant", "ENTY_product", "ENTY_religion", "ENTY_sport", "ENTY_substance", "ENTY_symbol", "ENTY_techmeth", "ENTY_termeq", "ENTY_veh", "ENTY_word", "HUM_desc", "HUM_gr", "HUM_ind", "HUM_title", "LOC_city", "LOC_country", "LOC_mount", "LOC_other", "LOC_state", "NUM_code", "NUM_count", "NUM_date", "NUM_dist", "NUM_money", "NUM_ord", "NUM_other", "NUM_perc", "NUM_period", "NUM_speed", "NUM_temp", "NUM_volsize", "NUM_weight"] + if self.bert: + # We do not need "UNK_UNK" label when using the original bert. + # In fact using a label that will not have any example in the training data + # can have serious consequences if your label_list is small. + label_list = label_list[1:] + + return label_list def _create_examples(self, input_file, set_type): """Creates examples for the training and dev sets."""