diff --git a/seq2seq/data/vocab.py b/seq2seq/data/vocab.py index e4d672ec..06546cd2 100644 --- a/seq2seq/data/vocab.py +++ b/seq2seq/data/vocab.py @@ -84,7 +84,10 @@ def create_vocabulary_lookup_table(filename, default_value=None): has_counts = len(vocab[0].split("\t")) == 2 if has_counts: - vocab, counts = zip(*[_.split("\t") for _ in vocab]) + pairs = [["\t", line.split("\t")[-1]] + if line.startswith("\t\t") else line.split("\t") + for line in vocab] + vocab, counts = zip(*pairs) counts = [float(_) for _ in counts] vocab = list(vocab) else: diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py index dedc6594..21fae653 100644 --- a/seq2seq/test/hooks_test.py +++ b/seq2seq/test/hooks_test.py @@ -47,7 +47,7 @@ def test_begin(self): with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file: file_contents = file.read().strip() - self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n" + self.assertEqual(file_contents, "_TFProfRoot (--/16.38k params)\n" " weigths (128x128, 16.38k/16.38k params)") outfile.close()