From a31c1fd1fa2413bc00f39ff5e41ccdc5c538e314 Mon Sep 17 00:00:00 2001
From: Darren Garvey
Date: Sun, 11 Jun 2017 15:20:22 +0100
Subject: [PATCH 01/12] Update imports to support TensorFlow v1.2.

---
 seq2seq/contrib/seq2seq/helper.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/seq2seq/contrib/seq2seq/helper.py b/seq2seq/contrib/seq2seq/helper.py
index 977d0ab9..353a1386 100644
--- a/seq2seq/contrib/seq2seq/helper.py
+++ b/seq2seq/contrib/seq2seq/helper.py
@@ -32,8 +32,13 @@
 import six
 
-from tensorflow.contrib.distributions.python.ops import bernoulli
-from tensorflow.contrib.distributions.python.ops import categorical
+try:
+  from tensorflow.python.ops.distributions import bernoulli
+  from tensorflow.python.ops.distributions import categorical
+except ImportError:
+  # Backwards compatibility with TensorFlow prior to 1.2.
+  from tensorflow.contrib.distributions.python.ops import bernoulli
+  from tensorflow.contrib.distributions.python.ops import categorical
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import base as layers_base
 

From 8053094a02340492e8f2ff5e594714dd7b17071d Mon Sep 17 00:00:00 2001
From: Darren Garvey
Date: Sun, 11 Jun 2017 15:18:50 +0100
Subject: [PATCH 02/12] Fix typo in hooks_test.py.

---
 seq2seq/test/hooks_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py
index dedc6594..a04cdabc 100644
--- a/seq2seq/test/hooks_test.py
+++ b/seq2seq/test/hooks_test.py
@@ -39,7 +39,7 @@ class TestPrintModelAnalysisHook(tf.test.TestCase):
   def test_begin(self):
     model_dir = tempfile.mkdtemp()
     outfile = tempfile.NamedTemporaryFile()
-    tf.get_variable("weigths", [128, 128])
+    tf.get_variable("weights", [128, 128])
     hook = hooks.PrintModelAnalysisHook(
         params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig())
     hook.begin()
@@ -125,7 +125,7 @@ def tearDown(self):
   def test_capture(self):
     global_step = tf.contrib.framework.get_or_create_global_step()
     # Some test computation
-    some_weights = tf.get_variable("weigths", [2, 128])
+    some_weights = tf.get_variable("weights", [2, 128])
     computation = tf.nn.softmax(some_weights)
 
     hook = hooks.MetadataCaptureHook(

From 4448976cbc0854ae406b4cfeeb235b705b95102a Mon Sep 17 00:00:00 2001
From: Darren Garvey
Date: Sun, 11 Jun 2017 15:19:27 +0100
Subject: [PATCH 03/12] Update hooks_test.py to work with TensorFlow 1.2.
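
TensorFlow 1.2 prepends a header line to the model analysis file that
tfprof writes out, so the contents the test expects now differ by
version. On 1.2 and later the file looks like:

  node name | # parameters
  _TFProfRoot (--/16.38k params)
    weights (128x128, 16.38k/16.38k params)

Guard the assertion on the installed TensorFlow version.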
---
 seq2seq/test/hooks_test.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py
index a04cdabc..70d395e7 100644
--- a/seq2seq/test/hooks_test.py
+++ b/seq2seq/test/hooks_test.py
@@ -24,6 +24,7 @@
 import tempfile
 import shutil
 import time
+from distutils.version import LooseVersion
 
 import tensorflow as tf
 from tensorflow.python.training import monitored_session  # pylint: disable=E0611
@@ -47,8 +48,14 @@ def test_begin(self):
     with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file:
       file_contents = file.read().strip()
 
-    self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n"
-                     "  weigths (128x128, 16.38k/16.38k params)")
+    if LooseVersion(tf.VERSION) < LooseVersion("1.2.0"):
+      self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n"
+                       "  weights (128x128, 16.38k/16.38k params)")
+    else:
+      self.assertEqual(file_contents.decode(), "node name | # parameters\n"
+                       "_TFProfRoot (--/16.38k params)\n"
+                       "  weights (128x128, 16.38k/16.38k params)")
+
     outfile.close()

From f85d026c0f7bff2c06c991012271e4c4fd1af1a6 Mon Sep 17 00:00:00 2001
From: Darren Garvey
Date: Sun, 11 Jun 2017 18:03:43 +0100
Subject: [PATCH 04/12] Fix hooks_test.py on Python 3.

Fixes the following error when running the tests on Python 3.

  Traceback (most recent call last):
    File "seq2seq/seq2seq/test/hooks_test.py", line 55, in test_begin
      self.assertEqual(file_contents.decode(), ...)
  AttributeError: 'str' object has no attribute 'decode'
---
 seq2seq/test/hooks_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py
index 70d395e7..8fe4db01 100644
--- a/seq2seq/test/hooks_test.py
+++ b/seq2seq/test/hooks_test.py
@@ -49,10 +49,10 @@ def test_begin(self):
       file_contents = file.read().strip()
 
     if LooseVersion(tf.VERSION) < LooseVersion("1.2.0"):
-      self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n"
+      self.assertEqual(tf.compat.as_text(file_contents), "_TFProfRoot (--/16.38k params)\n"
                        "  weights (128x128, 16.38k/16.38k params)")
     else:
-      self.assertEqual(file_contents.decode(), "node name | # parameters\n"
+      self.assertEqual(tf.compat.as_text(file_contents), "node name | # parameters\n"
                        "_TFProfRoot (--/16.38k params)\n"
                        "  weights (128x128, 16.38k/16.38k params)")
 

From 83d592f2e9d00c6cb1597fd27b8dcdd2b3277e70 Mon Sep 17 00:00:00 2001
From: Darren Garvey
Date: Sun, 11 Jun 2017 18:21:40 +0100
Subject: [PATCH 05/12] Another go at fixing hooks_test.py.

distutils.version isn't available in the Python versions used on CI, so
change the test to avoid relying explicitly on the version of TensorFlow
installed.
---
 seq2seq/test/hooks_test.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py
index 8fe4db01..abea8125 100644
--- a/seq2seq/test/hooks_test.py
+++ b/seq2seq/test/hooks_test.py
@@ -24,7 +24,6 @@
 import tempfile
 import shutil
 import time
-from distutils.version import LooseVersion
 
 import tensorflow as tf
 from tensorflow.python.training import monitored_session  # pylint: disable=E0611
@@ -48,13 +47,13 @@ def test_begin(self):
     with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file:
       file_contents = file.read().strip()
 
-    if LooseVersion(tf.VERSION) < LooseVersion("1.2.0"):
-      self.assertEqual(tf.compat.as_text(file_contents), "_TFProfRoot (--/16.38k params)\n"
-                       "  weights (128x128, 16.38k/16.38k params)")
-    else:
-      self.assertEqual(tf.compat.as_text(file_contents), "node name | # parameters\n"
-                       "_TFProfRoot (--/16.38k params)\n"
-                       "  weights (128x128, 16.38k/16.38k params)")
+    lines = tf.compat.as_text(file_contents).split("\n")
+    if len(lines) == 3:
+      # TensorFlow v1.2 includes an extra header line
+      self.assertEqual(lines[0], "node name | # parameters")
+
+    self.assertEqual(lines[-2], "_TFProfRoot (--/16.38k params)")
+    self.assertEqual(lines[-1], "  weights (128x128, 16.38k/16.38k params)")
 
     outfile.close()

From ffb1505655b3e3836b054da2b0b2981da2e29a5c Mon Sep 17 00:00:00 2001
From: Darren Garvey
Date: Sun, 18 Jun 2017 14:36:28 +0100
Subject: [PATCH 06/12] Tell pylint about the tf contextmanager.

Use of `tf.name_scope` and `tf.variable_scope` causes pylint errors on
TF 1.2 because pylint does not fully understand the
`tf_contextlib.contextmanager` decorators.
---
 pylintrc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pylintrc b/pylintrc
index 2a182c2f..0e75ca0c 100644
--- a/pylintrc
+++ b/pylintrc
@@ -292,7 +292,7 @@ generated-members=set_shape,np.float32
 # List of decorators that produce context managers, such as
 # contextlib.contextmanager. Add to this list to register other decorators that
 # produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
+contextmanager-decorators=contextlib.contextmanager,tensorflow.python.util.tf_contextlib.contextmanager
 
 
 [VARIABLES]

From 0c20ee4d8546dd905871049825d0883eda491bca Mon Sep 17 00:00:00 2001
From: Darren Garvey
Date: Sun, 18 Jun 2017 15:00:57 +0100
Subject: [PATCH 07/12] Work around no-name-in-module pylint errors.

Following some changes in TF to the `LazyLoader` [1], pylint complains
about not being able to find some imports under `tf.contrib`.

This looks like a pylint issue, emitting errors like:

  ************* Module seq2seq.encoders.rnn_encoder
  E: 24, 0: No name 'rnn' in module 'LazyLoader' (no-name-in-module)
  ************* Module seq2seq.data.input_pipeline
  E: 32, 0: No name 'slim' in module 'LazyLoader' (no-name-in-module)

[1] https://github.com/tensorflow/tensorflow/commit/95c5d7e880b8b4d18ba1f3b7cf40d15cd218b3c9
---
 seq2seq/data/input_pipeline.py           | 2 ++
 seq2seq/data/parallel_data_provider.py   | 2 ++
 seq2seq/data/sequence_example_decoder.py | 2 ++
 seq2seq/data/split_tokens_decoder.py     | 2 ++
 seq2seq/encoders/image_encoder.py        | 2 ++
 seq2seq/encoders/rnn_encoder.py          | 3 +--
 seq2seq/metrics/metric_specs.py          | 2 ++
 7 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/seq2seq/data/input_pipeline.py b/seq2seq/data/input_pipeline.py
index 1d22b5cc..b452adf0 100644
--- a/seq2seq/data/input_pipeline.py
+++ b/seq2seq/data/input_pipeline.py
@@ -29,7 +29,9 @@
 import six
 
 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
+# pylint: enable=no-name-in-module
 
 from seq2seq.configurable import Configurable
 from seq2seq.data import split_tokens_decoder, parallel_data_provider

diff --git a/seq2seq/data/parallel_data_provider.py b/seq2seq/data/parallel_data_provider.py
index 269add9d..c07046bb 100644
--- a/seq2seq/data/parallel_data_provider.py
+++ b/seq2seq/data/parallel_data_provider.py
@@ -22,8 +22,10 @@
 import numpy as np
 import tensorflow as tf
 
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import data_provider
 from tensorflow.contrib.slim.python.slim.data import parallel_reader
+# pylint: enable=no-name-in-module
 
 from seq2seq.data import split_tokens_decoder

diff --git a/seq2seq/data/sequence_example_decoder.py b/seq2seq/data/sequence_example_decoder.py
index 6eb76181..cc286c5d 100644
--- a/seq2seq/data/sequence_example_decoder.py
+++ b/seq2seq/data/sequence_example_decoder.py
@@ -14,7 +14,9 @@
 """A decoder for tf.SequenceExample"""
 
 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import data_decoder
+# pylint: enable=no-name-in-module
 
 
 class TFSEquenceExampleDecoder(data_decoder.DataDecoder):

diff --git a/seq2seq/data/split_tokens_decoder.py b/seq2seq/data/split_tokens_decoder.py
index c6c1efe2..2dd04200 100644
--- a/seq2seq/data/split_tokens_decoder.py
+++ b/seq2seq/data/split_tokens_decoder.py
@@ -21,7 +21,9 @@
 from __future__ import unicode_literals
 
 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import data_decoder
+# pylint: enable=no-name-in-module
 
 
 class SplitTokensDecoder(data_decoder.DataDecoder):

diff --git a/seq2seq/encoders/image_encoder.py b/seq2seq/encoders/image_encoder.py
index f8dac5d8..b8edebbe 100644
--- a/seq2seq/encoders/image_encoder.py
+++ b/seq2seq/encoders/image_encoder.py
@@ -20,8 +20,10 @@
 from __future__ import print_function
 
 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.nets.inception_v3 \
   import inception_v3_base
+# pylint: enable=no-name-in-module
 
 from seq2seq.encoders.encoder import Encoder, EncoderOutput

diff --git a/seq2seq/encoders/rnn_encoder.py b/seq2seq/encoders/rnn_encoder.py
index d21338df..2cbf4a24 100644
--- a/seq2seq/encoders/rnn_encoder.py
+++ b/seq2seq/encoders/rnn_encoder.py
@@ -21,7 +21,6 @@
 import copy
 
 import tensorflow as tf
-from tensorflow.contrib.rnn.python.ops import rnn
 
 from seq2seq.encoders.encoder import Encoder, EncoderOutput
 from seq2seq.training import utils as training_utils
@@ -186,7 +185,7 @@ def encode(self, inputs, sequence_length, **kwargs):
     cells_fw = _unpack_cell(cell_fw)
     cells_bw = _unpack_cell(cell_bw)
 
-    result = rnn.stack_bidirectional_dynamic_rnn(
+    result = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
         cells_fw=cells_fw,
         cells_bw=cells_bw,
         inputs=inputs,

diff --git a/seq2seq/metrics/metric_specs.py b/seq2seq/metrics/metric_specs.py
index e4c4ceaa..c6da8709 100644
--- a/seq2seq/metrics/metric_specs.py
+++ b/seq2seq/metrics/metric_specs.py
@@ -28,7 +28,9 @@
 import tensorflow as tf
 from tensorflow.contrib import metrics
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.learn import MetricSpec
+# pylint: enable=no-name-in-module
 
 from seq2seq.data import postproc
 from seq2seq.configurable import Configurable

From 6ef2992d1bbd934fcb902c1b044cd67fbf55de9f Mon Sep 17 00:00:00 2001
From: pratapbhanu
Date: Wed, 5 Jul 2017 13:32:23 +0200
Subject: [PATCH 08/12] Add support for using pre-trained embeddings as source
 embeddings.

---
 seq2seq/data/embeddings.py      | 36 +++++++++++++++++++++++
 seq2seq/data/vocab.py           | 51 +++++++++++++++++++++------------
 seq2seq/models/seq2seq_model.py | 25 ++++++++++++----
 seq2seq/test/models_test.py     | 29 +++++++++++++++++++
 seq2seq/test/utils.py           | 21 ++++++++++++++
 5 files changed, 137 insertions(+), 25 deletions(-)
 create mode 100644 seq2seq/data/embeddings.py

diff --git a/seq2seq/data/embeddings.py b/seq2seq/data/embeddings.py
new file mode 100644
index 00000000..f4e8ead0
--- /dev/null
+++ b/seq2seq/data/embeddings.py
@@ -0,0 +1,36 @@
+'''
+Created on Jul 4, 2017
+
+@author: Bhanu
+
+'''
+import numpy as np
+from seq2seq.data import vocab
+
+
+def read_embeddings(embeddings_path, vocab_path):
+  vocab_, _, _ = vocab.read_vocab(vocab_path)
+  word2vec = {}
+  with open(embeddings_path, 'rt', encoding='utf-8') as vec_file:
+    for line in vec_file:
+      parts = line.split(',')
+      word = parts[0]
+
+      if word not in vocab_: continue
+
+      vec = parts[1:]
+      word2vec[word] = vec
+
+  unknown_words = [w for w in vocab_ if w not in word2vec]
+  emb_dim = len(word2vec.get(vocab_[0]))
+  rnd_vecs = [np.random.uniform(-0.25, 0.25,
+                                size=emb_dim).tolist() for _ in unknown_words]
+  print("adding %d unknown words to vocab"%len(unknown_words))
+  word2vec.update(dict(zip(unknown_words,rnd_vecs)))
+
+  vecs = [word2vec.get(w) for w in vocab_]
+  embedding_mat = np.asarray(vecs, dtype=np.float32)
+
+  return embedding_mat
+
+

diff --git a/seq2seq/data/vocab.py b/seq2seq/data/vocab.py
index e4d672ec..3460470e 100644
--- a/seq2seq/data/vocab.py
+++ b/seq2seq/data/vocab.py
@@ -60,6 +60,36 @@ def get_special_vocab(vocabulary_size):
   return SpecialVocab(*range(vocabulary_size, vocabulary_size + 3))
 
 
+def read_vocab(filename):
+  """Reads vocab file into memory and adds special vocab to it.
+
+  Args:
+    filename: Path to a vocabulary file containing one word per line.
+      Each word is mapped to its line number.
+
+  Returns:
+    A tuple (vocab, counts, special_vocab)
+  """
+  tf.logging.info("Reading vocabulary from %s"%filename)
+  with gfile.GFile( # Load vocabulary into memory
+      filename) as file:
+    vocab = list(line.strip("\n") for line in file)
+  vocab_size = len(vocab)
+  has_counts = len(vocab[0].split("\t")) == 2
+  if has_counts:
+    vocab, counts = zip(*[_.split("\t") for _ in vocab])
+    counts = [float(_) for _ in counts]
+    vocab = list(vocab)
+  else:
+    counts = [-1. for _ in vocab]
+
+  # Add special vocabulary items
+  special_vocab = get_special_vocab(vocab_size)
+  vocab += list(special_vocab._fields)
+  counts += [-1. for _ in list(special_vocab._fields)]
+
+  return vocab, counts, special_vocab
+
+
 def create_vocabulary_lookup_table(filename, default_value=None):
   """Creates a lookup table for a vocabulary file.
@@ -77,27 +107,10 @@ def create_vocabulary_lookup_table(filename, default_value=None):
   if not gfile.Exists(filename):
     raise ValueError("File does not exist: {}".format(filename))
 
-  # Load vocabulary into memory
-  with gfile.GFile(filename) as file:
-    vocab = list(line.strip("\n") for line in file)
-  vocab_size = len(vocab)
-
-  has_counts = len(vocab[0].split("\t")) == 2
-  if has_counts:
-    vocab, counts = zip(*[_.split("\t") for _ in vocab])
-    counts = [float(_) for _ in counts]
-    vocab = list(vocab)
-  else:
-    counts = [-1. for _ in vocab]
-
-  # Add special vocabulary items
-  special_vocab = get_special_vocab(vocab_size)
-  vocab += list(special_vocab._fields)
-  vocab_size += len(special_vocab)
-  counts += [-1. for _ in list(special_vocab._fields)]
-
+  vocab, counts, special_vocab = read_vocab(filename)
   if default_value is None:
     default_value = special_vocab.UNK
+  vocab_size = len(vocab)
 
   tf.logging.info("Creating vocabulary lookup table of size %d", vocab_size)

diff --git a/seq2seq/models/seq2seq_model.py b/seq2seq/models/seq2seq_model.py
index 423ffb75..db5ab150 100644
--- a/seq2seq/models/seq2seq_model.py
+++ b/seq2seq/models/seq2seq_model.py
@@ -30,6 +30,7 @@
 from seq2seq.decoders.beam_search_decoder import BeamSearchDecoder
 from seq2seq.inference import beam_search
 from seq2seq.models.model_base import ModelBase, _flatten_dict
+from seq2seq.data.embeddings import read_embeddings
 
 
 class Seq2SeqModel(ModelBase):
@@ -46,6 +47,11 @@ def __init__(self, params, mode, name):
     self.target_vocab_info = None
     if "vocab_target" in self.params and self.params["vocab_target"]:
       self.target_vocab_info = vocab.get_vocab_info(self.params["vocab_target"])
+
+    self.embedding_mat = None
+    if "embedding.file" in self.params and self.params["embedding.file"]:
+      embedding_mat = read_embeddings(self.params['embedding.file'],
+                                      self.source_vocab_info.path)
 
   @staticmethod
   def default_params():
@@ -57,6 +63,8 @@ def default_params():
         "embedding.dim": 100,
         "embedding.init_scale": 0.04,
         "embedding.share": False,
+        "embedding.file": None,
+        "embedding.tune": True,
         "inference.beam_search.beam_width": 0,
         "inference.beam_search.length_penalty_weight": 0.0,
         "inference.beam_search.choose_successors_fn": "choose_top_k",
@@ -128,12 +136,17 @@ def batch_size(self, features, labels):
   def source_embedding(self):
     """Returns the embedding used for the source sequence.
""" - return tf.get_variable( - name="W", - shape=[self.source_vocab_info.total_size, self.params["embedding.dim"]], - initializer=tf.random_uniform_initializer( - -self.params["embedding.init_scale"], - self.params["embedding.init_scale"])) + if self.embedding_mat: + self.params.update({"embedding.dim":self.embedding_mat.shape[1]}) + initializer = tf.constant(self.embedding_mat, dtype=tf.float32) + shape_ = None + else: + initializer = tf.random_uniform_initializer(-self.params["embedding.init_scale"], + self.params["embedding.init_scale"], dtype=tf.float32) + shape_ = [self.source_vocab_info.total_size, self.params["embedding.dim"]] + + return tf.get_variable(name="W", shape=shape_, initializer=initializer, + trainable=self.params['embedding.tune']) @property @templatemethod("target_embedding") diff --git a/seq2seq/test/models_test.py b/seq2seq/test/models_test.py index a2009851..be386315 100644 --- a/seq2seq/test/models_test.py +++ b/seq2seq/test/models_test.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Tests for Models """ @@ -27,6 +28,7 @@ import tensorflow as tf from seq2seq.data import vocab, input_pipeline +from seq2seq.data.embeddings import read_embeddings from seq2seq.training import utils as training_utils from seq2seq.test import utils as test_utils from seq2seq.models import BasicSeq2Seq, AttentionSeq2Seq @@ -227,6 +229,33 @@ def create_model(self, mode, params=None): params_.update(params or {}) return AttentionSeq2Seq(params=params_, mode=mode) +class EmbeddingsFileTest(EncoderDecoderTests): + """ + Tests for using pre-trained embeddings as source embeddings. + """ + + def setUp(self): + super(TestBasicSeq2Seq, self).setUp() + vocab_list = ["Hello", "there"] + self.embeddings_file = test_utils.create_temporary_embeddings_file(vocab_list, dim=10) + + def create_model(self, mode, params=None): + params_ = BasicSeq2Seq.default_params().copy() + params_.update(TEST_PARAMS) + params_.update({ + "vocab_source": self.vocab_file.name, + "vocab_target": self.vocab_file.name, + "bridge.class": "PassThroughBridge", + "embedding.file": self.embeddings_file.name, + "embedding.tune": False + }) + params_.update(params or {}) + return BasicSeq2Seq(params=params_, mode=mode) + + def test_read_embeddings(self): + embeddings_mat = read_embeddings(self.embeddings_file.name) + assert embeddings_mat.shape[1] == 10 + if __name__ == "__main__": tf.test.main() diff --git a/seq2seq/test/utils.py b/seq2seq/test/utils.py index ffe936d9..eb27ec7e 100644 --- a/seq2seq/test/utils.py +++ b/seq2seq/test/utils.py @@ -21,6 +21,7 @@ import tempfile import tensorflow as tf +import numpy as np def create_temp_parallel_data(sources, targets): @@ -89,3 +90,23 @@ def create_temporary_vocab_file(words, counts=None): vocab_file.write("{}\t{}\n".format(token, count).encode("utf-8")) vocab_file.flush() return vocab_file + +def create_temporary_embeddings_file(words, dim=10): + """ + Creates a temporary Embeddings file. + + Args: + words: List of words in the vocabulary + dim: embeddings dimension + + Returns: + A temporary file object with one word and its vector(float values) per line, + each separated by a blank space. 
+ """ + embed_file = tempfile.NamedTemporaryFile() + for token in words: + vec = " ".join(np.random.uniform(-0.25, 0.25, size=dim).tolist()) + embed_file.write((token+" "+vec + "\n").encode("utf-8")) + embed_file.flush() + return embed_file + From 65d55193600fbcf342d245e6eb9519d566988a7e Mon Sep 17 00:00:00 2001 From: pratapbhanu Date: Wed, 5 Jul 2017 14:31:37 +0200 Subject: [PATCH 09/12] fixing tests --- seq2seq/data/embeddings.py | 23 +++++++++++++++++------ seq2seq/test/models_test.py | 12 +++++++----- seq2seq/test/utils.py | 2 +- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/seq2seq/data/embeddings.py b/seq2seq/data/embeddings.py index f4e8ead0..b384dbdf 100644 --- a/seq2seq/data/embeddings.py +++ b/seq2seq/data/embeddings.py @@ -6,26 +6,37 @@ ''' import numpy as np from seq2seq.data import vocab - +import tensorflow as tf def read_embeddings(embeddings_path, vocab_path): - vocab_, _, _ = vocab.read_vocab(vocab_path) + """Reads embeddings file. + + Args: + embeddings_path: full path for the embeddings file, where embeddings file + contains word and its vector(float values) per line, separated by blank space. + vocab_path: full path for the vocab file, where each line contains a single + vocab word. + + Returns: + a 2d array where row index corresponds to the word index in the vocab (special + vocab and other unknown words are also included at their respective row index. + """ + vocab_, _, _ = vocab.read_vocab(vocab_path) word2vec = {} with open(embeddings_path, 'rt', encoding='utf-8') as vec_file: for line in vec_file: - parts = line.split(',') + parts = line.split(' ') word = parts[0] - + emb_dim = len(parts) - 1 if word not in vocab_: continue vec = parts[1:] word2vec[word] = vec unknown_words = [w for w in vocab_ if w not in word2vec] - emb_dim = len(word2vec.get(vocab_[0])) rnd_vecs = [np.random.uniform(-0.25, 0.25, size=emb_dim).tolist() for _ in unknown_words] - print("adding %d unknown words to vocab"%len(unknown_words)) + tf.logging.info("adding %d unknown words to vocab"%len(unknown_words)) word2vec.update(dict(zip(unknown_words,rnd_vecs))) vecs = [word2vec.get(w) for w in vocab_] diff --git a/seq2seq/test/models_test.py b/seq2seq/test/models_test.py index be386315..1f4e5bae 100644 --- a/seq2seq/test/models_test.py +++ b/seq2seq/test/models_test.py @@ -235,9 +235,10 @@ class EmbeddingsFileTest(EncoderDecoderTests): """ def setUp(self): - super(TestBasicSeq2Seq, self).setUp() - vocab_list = ["Hello", "there"] - self.embeddings_file = test_utils.create_temporary_embeddings_file(vocab_list, dim=10) + super(EmbeddingsFileTest, self).setUp() + self.emb_dim = 10 + self.embeddings_file = test_utils.create_temporary_embeddings_file(self.vocab_list, + dim=self.emb_dim) def create_model(self, mode, params=None): params_ = BasicSeq2Seq.default_params().copy() @@ -253,8 +254,9 @@ def create_model(self, mode, params=None): return BasicSeq2Seq(params=params_, mode=mode) def test_read_embeddings(self): - embeddings_mat = read_embeddings(self.embeddings_file.name) - assert embeddings_mat.shape[1] == 10 + embeddings_mat = read_embeddings(self.embeddings_file.name, self.vocab_info.path) + assert embeddings_mat.shape[1] == self.emb_dim, "Embeddings Dimension \ + should be %d but found %d"%(self.emb_dim, embeddings_mat.shape[1]) if __name__ == "__main__": diff --git a/seq2seq/test/utils.py b/seq2seq/test/utils.py index eb27ec7e..cbb45dbe 100644 --- a/seq2seq/test/utils.py +++ b/seq2seq/test/utils.py @@ -105,7 +105,7 @@ def create_temporary_embeddings_file(words, dim=10): """ 
   embed_file = tempfile.NamedTemporaryFile()
   for token in words:
-    vec = " ".join(np.random.uniform(-0.25, 0.25, size=dim).tolist())
+    vec = " ".join([str(x) for x in np.random.uniform(-0.25, 0.25, size=dim).tolist()])
     embed_file.write((token+" "+vec + "\n").encode("utf-8"))
   embed_file.flush()
   return embed_file

From 6403218f06c34d932a1c6fa06ec55c27733c1981 Mon Sep 17 00:00:00 2001
From: pratapbhanu
Date: Wed, 5 Jul 2017 16:47:18 +0200
Subject: [PATCH 10/12] Fix pylint warnings.

---
 seq2seq/data/embeddings.py      | 32 ++++++++++++++++----------------
 seq2seq/data/vocab.py           |  9 ++++-----
 seq2seq/models/seq2seq_model.py | 14 ++++++++------
 seq2seq/test/models_test.py     |  9 +++++----
 seq2seq/test/utils.py           | 19 ++++++++++---------
 5 files changed, 43 insertions(+), 40 deletions(-)

diff --git a/seq2seq/data/embeddings.py b/seq2seq/data/embeddings.py
index b384dbdf..53078e9c 100644
--- a/seq2seq/data/embeddings.py
+++ b/seq2seq/data/embeddings.py
@@ -10,18 +10,20 @@ def read_embeddings(embeddings_path, vocab_path):
   """Reads embeddings file.
-
+
   Args:
-    embeddings_path: full path for the embeddings file, where embeddings file
-      contains word and its vector(float values) per line, separated by blank space.
-    vocab_path: full path for the vocab file, where each line contains a single
-      vocab word.
-
+    embeddings_path: full path for the embeddings file,
+      where embeddings file contains word and its vector(float values)
+      per line, separated by blank space.
+    vocab_path: full path for the vocab file,
+      where each line contains a single vocab word.
+
   Returns:
-    a 2d array where row index corresponds to the word index in the vocab (special
-    vocab and other unknown words are also included at their respective row index).
+    a 2d array where row index corresponds to the word index
+    in the vocab (special vocab and other unknown words are
+    also included at their respective row index).
   """
-  vocab_, _, _ = vocab.read_vocab(vocab_path)
+  vocab_, _, _ = vocab.read_vocab(vocab_path)
   word2vec = {}
@@ -32,16 +34,14 @@ def read_embeddings(embeddings_path, vocab_path):
 
       vec = parts[1:]
       word2vec[word] = vec
-
+
   unknown_words = [w for w in vocab_ if w not in word2vec]
-  rnd_vecs = [np.random.uniform(-0.25, 0.25,
-                                size=emb_dim).tolist() for _ in unknown_words]
-  tf.logging.info("adding %d unknown words to vocab"%len(unknown_words))
-  word2vec.update(dict(zip(unknown_words,rnd_vecs)))
+  rnd_vecs = [np.random.uniform(-0.25, 0.25, size=emb_dim)
+              .tolist() for _ in unknown_words]
+  tf.logging.info("adding %d unknown words to vocab", len(unknown_words))
+  word2vec.update(dict(zip(unknown_words, rnd_vecs)))
 
   vecs = [word2vec.get(w) for w in vocab_]
   embedding_mat = np.asarray(vecs, dtype=np.float32)
 
   return embedding_mat
-
-

diff --git a/seq2seq/data/vocab.py b/seq2seq/data/vocab.py
index 3460470e..ffb2c0ec 100644
--- a/seq2seq/data/vocab.py
+++ b/seq2seq/data/vocab.py
@@ -67,14 +67,13 @@ def read_vocab(filename):
   Args:
     filename: Path to a vocabulary file containing one word per line.
       Each word is mapped to its line number.
-
+
   Returns:
     A tuple (vocab, counts, special_vocab)
   """
-  tf.logging.info("Reading vocabulary from %s"%filename)
-  with gfile.GFile( # Load vocabulary into memory
-      filename) as file:
-    vocab = list(line.strip("\n") for line in file)
+  tf.logging.info("Reading vocabulary from %s", filename)
+  with gfile.GFile(filename) as file:
+    vocab = list(line.strip("\n") for line in file)
   vocab_size = len(vocab)
   has_counts = len(vocab[0].split("\t")) == 2
   if has_counts:

diff --git a/seq2seq/models/seq2seq_model.py b/seq2seq/models/seq2seq_model.py
index db5ab150..ceb62820 100644
--- a/seq2seq/models/seq2seq_model.py
+++ b/seq2seq/models/seq2seq_model.py
@@ -47,11 +47,11 @@ def __init__(self, params, mode, name):
     self.target_vocab_info = None
     if "vocab_target" in self.params and self.params["vocab_target"]:
       self.target_vocab_info = vocab.get_vocab_info(self.params["vocab_target"])
-
+
     self.embedding_mat = None
     if "embedding.file" in self.params and self.params["embedding.file"]:
-      embedding_mat = read_embeddings(self.params['embedding.file'],
-                                      self.source_vocab_info.path)
+      self.embedding_mat = read_embeddings(self.params['embedding.file'],
+                                           self.source_vocab_info.path)
 
   @staticmethod
   def default_params():
@@ -141,10 +141,12 @@ def source_embedding(self):
       initializer = tf.constant(self.embedding_mat, dtype=tf.float32)
       shape_ = None
     else:
-      initializer = tf.random_uniform_initializer(-self.params["embedding.init_scale"],
-                                                  self.params["embedding.init_scale"], dtype=tf.float32)
+      initializer = tf.random_uniform_initializer(
+          -self.params["embedding.init_scale"],
+          self.params["embedding.init_scale"],
+          dtype=tf.float32)
       shape_ = [self.source_vocab_info.total_size, self.params["embedding.dim"]]
-
+
     return tf.get_variable(name="W", shape=shape_, initializer=initializer,
                            trainable=self.params['embedding.tune'])

diff --git a/seq2seq/test/models_test.py b/seq2seq/test/models_test.py
index 1f4e5bae..1dac0bd5 100644
--- a/seq2seq/test/models_test.py
+++ b/seq2seq/test/models_test.py
@@ -237,8 +237,8 @@ class EmbeddingsFileTest(EncoderDecoderTests):
   def setUp(self):
     super(EmbeddingsFileTest, self).setUp()
     self.emb_dim = 10
-    self.embeddings_file = test_utils.create_temporary_embeddings_file(self.vocab_list,
-                                                                       dim=self.emb_dim)
+    self.embeddings_file = test_utils.create_temp_embedding_file(
+        self.vocab_list, dim=self.emb_dim)
 
   def create_model(self, mode, params=None):
     params_ = BasicSeq2Seq.default_params().copy()
@@ -254,9 +254,10 @@ def create_model(self, mode, params=None):
     return BasicSeq2Seq(params=params_, mode=mode)
 
   def test_read_embeddings(self):
-    embeddings_mat = read_embeddings(self.embeddings_file.name, self.vocab_info.path)
+    embeddings_mat = read_embeddings(self.embeddings_file.name,
+                                     self.vocab_info.path)
     assert embeddings_mat.shape[1] == self.emb_dim, "Embeddings Dimension \
     should be %d but found %d"%(self.emb_dim, embeddings_mat.shape[1])

diff --git a/seq2seq/test/utils.py b/seq2seq/test/utils.py
index cbb45dbe..c01a3818 100644
--- a/seq2seq/test/utils.py
+++ b/seq2seq/test/utils.py
@@ -90,23 +90,24 @@ def create_temporary_vocab_file(words, counts=None):
     vocab_file.write("{}\t{}\n".format(token, count).encode("utf-8"))
   vocab_file.flush()
   return vocab_file
-
-def create_temporary_embeddings_file(words, dim=10):
+
+def create_temp_embedding_file(words, dim=10):
   """
   Creates a temporary Embeddings file.
-
+
   Args:
-    words: List of words in the vocabulary
-    dim: embeddings dimension
-
+    words: List of words in the vocabulary
+    dim: embeddings dimension
+
   Returns:
-    A temporary file object with one word and its vector(float values) per line,
+    A temporary file object with one word and
+    its vector(float values) per line,
     each separated by a blank space.
   """
   embed_file = tempfile.NamedTemporaryFile()
   for token in words:
-    vec = " ".join([str(x) for x in np.random.uniform(-0.25, 0.25, size=dim).tolist()])
+    vec = " ".join([str(x) for x in
+                    np.random.uniform(-0.25, 0.25, size=dim).tolist()])
     embed_file.write((token+" "+vec + "\n").encode("utf-8"))
   embed_file.flush()
   return embed_file
-

From ab9877122bc29a9b913916147b140661898d28d3 Mon Sep 17 00:00:00 2001
From: pratapbhanu
Date: Wed, 5 Jul 2017 17:34:40 +0200
Subject: [PATCH 11/12] Fix the embeddings check condition in seq2seq_model.

Truth-testing a multi-element numpy array raises a ValueError, so compare
the embedding matrix against None explicitly.
---
 seq2seq/models/seq2seq_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/seq2seq/models/seq2seq_model.py b/seq2seq/models/seq2seq_model.py
index ceb62820..6779f49f 100644
--- a/seq2seq/models/seq2seq_model.py
+++ b/seq2seq/models/seq2seq_model.py
@@ -136,7 +136,7 @@ def batch_size(self, features, labels):
   def source_embedding(self):
     """Returns the embedding used for the source sequence.
     """
-    if self.embedding_mat:
+    if self.embedding_mat is not None:
       self.params.update({"embedding.dim":self.embedding_mat.shape[1]})
       initializer = tf.constant(self.embedding_mat, dtype=tf.float32)
       shape_ = None

From 78c396a595e012e97c96947212b6e1b3a2fd2d45 Mon Sep 17 00:00:00 2001
From: pratapbhanu
Date: Wed, 5 Jul 2017 18:27:08 +0200
Subject: [PATCH 12/12] Fix reading the embeddings file on Python 2.7.

The encoding argument to open() is only available on Python 3.
---
 seq2seq/data/embeddings.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/seq2seq/data/embeddings.py b/seq2seq/data/embeddings.py
index 53078e9c..35869152 100644
--- a/seq2seq/data/embeddings.py
+++ b/seq2seq/data/embeddings.py
@@ -25,7 +25,7 @@ def read_embeddings(embeddings_path, vocab_path):
   """
   vocab_, _, _ = vocab.read_vocab(vocab_path)
   word2vec = {}
-  with open(embeddings_path, 'rt', encoding='utf-8') as vec_file:
+  with open(embeddings_path, 'r') as vec_file:
     for line in vec_file:
       parts = line.split(' ')
       word = parts[0]
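
Usage note: with this series applied, pre-trained vectors are supplied
through the new "embedding.file" and "embedding.tune" model params and
loaded via seq2seq.data.embeddings.read_embeddings. A minimal sketch of
the file formats involved (the paths, words, and dimension below are
hypothetical, not part of the patches):

  import tempfile

  import numpy as np

  from seq2seq.data.embeddings import read_embeddings

  # Vocabulary file: one word per line, ordered by word id.
  vocab_file = tempfile.NamedTemporaryFile()
  vocab_file.write(b"Hello\nthere\n")
  vocab_file.flush()

  # Embeddings file: a word followed by its space-separated float vector.
  embed_file = tempfile.NamedTemporaryFile()
  for word in ["Hello", "there"]:
    vec = " ".join(str(x) for x in np.random.uniform(-0.25, 0.25, size=10))
    embed_file.write(("%s %s\n" % (word, vec)).encode("utf-8"))
  embed_file.flush()

  # Rows follow vocabulary order; words without a pre-trained vector
  # (including the special tokens that read_vocab appends, e.g. UNK)
  # are filled with random vectors.
  embedding_mat = read_embeddings(embed_file.name, vocab_file.name)
  print(embedding_mat.shape)  # (2 words + 3 special tokens, 10) -> (5, 10)

Setting "embedding.tune": False keeps the loaded matrix fixed during
training (the embedding variable is created with trainable=False).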