diff --git a/seq2seq/contrib/seq2seq/helper.py b/seq2seq/contrib/seq2seq/helper.py index 977d0ab9..d42d5e21 100644 --- a/seq2seq/contrib/seq2seq/helper.py +++ b/seq2seq/contrib/seq2seq/helper.py @@ -32,8 +32,15 @@ import six -from tensorflow.contrib.distributions.python.ops import bernoulli -from tensorflow.contrib.distributions.python.ops import categorical +# from tensorflow.contrib.distributions.python.ops import bernoulli +# from tensorflow.contrib.distributions.python.ops import categorical +try: + from tensorflow.python.ops.distributions import bernoulli + from tensorflow.python.ops.distributions import categorical +except: + from tensorflow.contrib.distributions.python.ops import bernoulli + from tensorflow.contrib.distributions.python.ops import categorical + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import base as layers_base diff --git a/seq2seq/test/pipeline_test.py b/seq2seq/test/pipeline_test.py index 8456997b..1bca72fe 100644 --- a/seq2seq/test/pipeline_test.py +++ b/seq2seq/test/pipeline_test.py @@ -41,7 +41,10 @@ def _clear_flags(): """Resets Tensorflow's FLAG values""" #pylint: disable=W0212 - tf.app.flags.FLAGS = tf.app.flags._FlagValues() + #tf.app.flags.FLAGS = tf.app.flags._FlagValues() + attr_names = [] + for flag_key in dir(tf.app.flags.FLAGS): + delattr(tf.app.flags.FLAGS, flag_key) tf.app.flags._global_parser = argparse.ArgumentParser() diff --git a/seq2seq/training/utils.py b/seq2seq/training/utils.py index 9451d57c..86eabaf9 100644 --- a/seq2seq/training/utils.py +++ b/seq2seq/training/utils.py @@ -115,7 +115,9 @@ def cell_from_spec(cell_classname, cell_params): cell_class = locate(cell_classname) or getattr(rnn_cell, cell_classname) # Make sure additional arguments are valid - cell_args = set(inspect.getargspec(cell_class.__init__).args[1:]) + #cell_args = set(inspect.getargspec(cell_class.__init__).args[1:]) + cell_args = set(inspect.signature(cell_class.__init__).parameters) + for key in cell_params.keys(): if key not in cell_args: raise ValueError(