This repository was archived by the owner on Dec 29, 2022. It is now read-only.
2 changes: 1 addition & 1 deletion pylintrc
@@ -292,7 +292,7 @@ generated-members=set_shape,np.float32
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
contextmanager-decorators=contextlib.contextmanager,tensorflow.python.util.tf_contextlib.contextmanager


[VARIABLES]
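This pylint option registers decorators that produce context managers, so functions decorated with tensorflow.python.util.tf_contextlib.contextmanager are no longer flagged as not-context-manager. A minimal sketch of the kind of code this unblocks (the name_scope function below is hypothetical):

from tensorflow.python.util import tf_contextlib

@tf_contextlib.contextmanager
def name_scope(name):
  # The decorated function yields, so the decorator builds a context manager.
  yield name

with name_scope("encoder") as scope:
  print(scope)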
9 changes: 7 additions & 2 deletions seq2seq/contrib/seq2seq/helper.py
@@ -32,8 +32,13 @@

import six

from tensorflow.contrib.distributions.python.ops import bernoulli
from tensorflow.contrib.distributions.python.ops import categorical
try:
  from tensorflow.python.ops.distributions import bernoulli
  from tensorflow.python.ops.distributions import categorical
except ImportError:
  # Backwards compatibility with TensorFlow prior to 1.2.
  from tensorflow.contrib.distributions.python.ops import bernoulli
  from tensorflow.contrib.distributions.python.ops import categorical
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import base as layers_base
47 changes: 47 additions & 0 deletions seq2seq/data/embeddings.py
@@ -0,0 +1,47 @@
'''
Created on Jul 4, 2017

@author: Bhanu
'''
import numpy as np
from seq2seq.data import vocab
import tensorflow as tf

def read_embeddings(embeddings_path, vocab_path):
  """Reads an embeddings file.

  Args:
    embeddings_path: Full path to the embeddings file, which contains
      one word and its vector (float values) per line, separated by
      blank spaces.
    vocab_path: Full path to the vocab file, where each line contains
      a single vocab word.

  Returns:
    A 2-D array whose row index corresponds to the word's index in
    the vocab (special vocab entries and other unknown words are
    also included at their respective row indices).
  """
  vocab_, _, _ = vocab.read_vocab(vocab_path)
  vocab_set = set(vocab_)  # Use a set for fast membership checks.
  word2vec = {}
  emb_dim = 0
  with open(embeddings_path, 'r') as vec_file:
    for line in vec_file:
      parts = line.split(' ')
      word = parts[0]
      emb_dim = len(parts) - 1
      if word not in vocab_set:
        continue
      word2vec[word] = parts[1:]

  # Vocab words missing from the embeddings file get random vectors.
  unknown_words = [w for w in vocab_ if w not in word2vec]
  rnd_vecs = [np.random.uniform(-0.25, 0.25, size=emb_dim).tolist()
              for _ in unknown_words]
  tf.logging.info("adding %d unknown words to vocab", len(unknown_words))
  word2vec.update(dict(zip(unknown_words, rnd_vecs)))

  vecs = [word2vec.get(w) for w in vocab_]
  embedding_mat = np.asarray(vecs, dtype=np.float32)

  return embedding_mat
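A minimal usage sketch for read_embeddings, assuming a vocab file with one word per line and an embeddings file with "word 0.1 0.2 ..." per line (paths are hypothetical):

embedding_mat = read_embeddings("data/embeddings.txt", "data/vocab.txt")
print(embedding_mat.shape)  # (vocab size including special tokens, emb_dim)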
2 changes: 2 additions & 0 deletions seq2seq/data/input_pipeline.py
@@ -29,7 +29,9 @@
import six

import tensorflow as tf
# pylint: disable=no-name-in-module
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
# pylint: enable=no-name-in-module

from seq2seq.configurable import Configurable
from seq2seq.data import split_tokens_decoder, parallel_data_provider
2 changes: 2 additions & 0 deletions seq2seq/data/parallel_data_provider.py
@@ -22,8 +22,10 @@
import numpy as np

import tensorflow as tf
# pylint: disable=no-name-in-module
from tensorflow.contrib.slim.python.slim.data import data_provider
from tensorflow.contrib.slim.python.slim.data import parallel_reader
# pylint: enable=no-name-in-module

from seq2seq.data import split_tokens_decoder

2 changes: 2 additions & 0 deletions seq2seq/data/sequence_example_decoder.py
@@ -14,7 +14,9 @@
"""A decoder for tf.SequenceExample"""

import tensorflow as tf
# pylint: disable=no-name-in-module
from tensorflow.contrib.slim.python.slim.data import data_decoder
# pylint: enable=no-name-in-module


class TFSEquenceExampleDecoder(data_decoder.DataDecoder):
2 changes: 2 additions & 0 deletions seq2seq/data/split_tokens_decoder.py
@@ -21,7 +21,9 @@
from __future__ import unicode_literals

import tensorflow as tf
# pylint: disable=no-name-in-module
from tensorflow.contrib.slim.python.slim.data import data_decoder
# pylint: enable=no-name-in-module


class SplitTokensDecoder(data_decoder.DataDecoder):
46 changes: 29 additions & 17 deletions seq2seq/data/vocab.py
@@ -60,44 +60,56 @@ def get_special_vocab(vocabulary_size):
  return SpecialVocab(*range(vocabulary_size, vocabulary_size + 3))


def create_vocabulary_lookup_table(filename, default_value=None):
  """Creates a lookup table for a vocabulary file.

def read_vocab(filename):
  """Reads a vocab file into memory and adds the special vocab to it.

  Args:
    filename: Path to a vocabulary file containing one word per line.
      Each word is mapped to its line number.
    default_value: UNK tokens will be mapped to this id.
      If None, UNK tokens will be mapped to [vocab_size].
      Each word is mapped to its line number.

  Returns:
    A tuple (vocab_to_id_table, id_to_vocab_table,
    word_to_count_table, vocab_size). The vocab size does not include
    the UNK token.
  """
  if not gfile.Exists(filename):
    raise ValueError("File does not exist: {}".format(filename))

  # Load vocabulary into memory
  Returns:
    A tuple (vocab, counts, special_vocab)
  """
  tf.logging.info("Reading vocabulary from %s", filename)
  with gfile.GFile(filename) as file:
    vocab = list(line.strip("\n") for line in file)
  vocab_size = len(vocab)

  has_counts = len(vocab[0].split("\t")) == 2
  if has_counts:
    vocab, counts = zip(*[_.split("\t") for _ in vocab])
    counts = [float(_) for _ in counts]
    vocab = list(vocab)
  else:
    counts = [-1. for _ in vocab]

  # Add special vocabulary items
  special_vocab = get_special_vocab(vocab_size)
  vocab += list(special_vocab._fields)
  vocab_size += len(special_vocab)
  counts += [-1. for _ in list(special_vocab._fields)]

  return vocab, counts, special_vocab

def create_vocabulary_lookup_table(filename, default_value=None):
  """Creates a lookup table for a vocabulary file.

  Args:
    filename: Path to a vocabulary file containing one word per line.
      Each word is mapped to its line number.
    default_value: UNK tokens will be mapped to this id.
      If None, UNK tokens will be mapped to [vocab_size].

  Returns:
    A tuple (vocab_to_id_table, id_to_vocab_table,
    word_to_count_table, vocab_size). The vocab size does not include
    the UNK token.
  """
  if not gfile.Exists(filename):
    raise ValueError("File does not exist: {}".format(filename))

  vocab, counts, special_vocab = read_vocab(filename)
  if default_value is None:
    default_value = special_vocab.UNK
  vocab_size = len(vocab)

  tf.logging.info("Creating vocabulary lookup table of size %d", vocab_size)

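A short sketch of read_vocab's return values, assuming the standard special vocab (SEQUENCE_START, SEQUENCE_END, UNK) is appended at the end; the path is hypothetical:

vocab_, counts, special_vocab = read_vocab("data/vocab.txt")
print(vocab_[-3:])        # the special-vocab entries appended after the file's words
print(special_vocab.UNK)  # integer id assigned to the UNK token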
2 changes: 2 additions & 0 deletions seq2seq/encoders/image_encoder.py
@@ -20,8 +20,10 @@
from __future__ import print_function

import tensorflow as tf
# pylint: disable=no-name-in-module
from tensorflow.contrib.slim.python.slim.nets.inception_v3 \
  import inception_v3_base
# pylint: enable=no-name-in-module

from seq2seq.encoders.encoder import Encoder, EncoderOutput

3 changes: 1 addition & 2 deletions seq2seq/encoders/rnn_encoder.py
@@ -21,7 +21,6 @@

import copy
import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import rnn

from seq2seq.encoders.encoder import Encoder, EncoderOutput
from seq2seq.training import utils as training_utils
@@ -186,7 +185,7 @@ def encode(self, inputs, sequence_length, **kwargs):
    cells_fw = _unpack_cell(cell_fw)
    cells_bw = _unpack_cell(cell_bw)

    result = rnn.stack_bidirectional_dynamic_rnn(
    result = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
        cells_fw=cells_fw,
        cells_bw=cells_bw,
        inputs=inputs,
2 changes: 2 additions & 0 deletions seq2seq/metrics/metric_specs.py
@@ -28,7 +28,9 @@

import tensorflow as tf
from tensorflow.contrib import metrics
# pylint: disable=no-name-in-module
from tensorflow.contrib.learn import MetricSpec
# pylint: enable=no-name-in-module

from seq2seq.data import postproc
from seq2seq.configurable import Configurable
27 changes: 21 additions & 6 deletions seq2seq/models/seq2seq_model.py
@@ -30,6 +30,7 @@
from seq2seq.decoders.beam_search_decoder import BeamSearchDecoder
from seq2seq.inference import beam_search
from seq2seq.models.model_base import ModelBase, _flatten_dict
from seq2seq.data.embeddings import read_embeddings


class Seq2SeqModel(ModelBase):
@@ -47,6 +48,11 @@ def __init__(self, params, mode, name):
if "vocab_target" in self.params and self.params["vocab_target"]:
self.target_vocab_info = vocab.get_vocab_info(self.params["vocab_target"])

self.embedding_mat = None
if "embedding.file" in self.params and self.params["embedding.file"]:
self.embedding_mat = read_embeddings(self.params['embedding.file'],
self.source_vocab_info.path)

  @staticmethod
  def default_params():
    params = ModelBase.default_params()
@@ -57,6 +63,8 @@ def default_params():
"embedding.dim": 100,
"embedding.init_scale": 0.04,
"embedding.share": False,
"embedding.file": None,
"embedding.tune": True,
"inference.beam_search.beam_width": 0,
"inference.beam_search.length_penalty_weight": 0.0,
"inference.beam_search.choose_successors_fn": "choose_top_k",
@@ -128,12 +136,19 @@ def batch_size(self, features, labels):
  def source_embedding(self):
    """Returns the embedding used for the source sequence.
    """
    return tf.get_variable(
        name="W",
        shape=[self.source_vocab_info.total_size, self.params["embedding.dim"]],
        initializer=tf.random_uniform_initializer(
            -self.params["embedding.init_scale"],
            self.params["embedding.init_scale"]))
    if self.embedding_mat is not None:
      self.params.update({"embedding.dim": self.embedding_mat.shape[1]})
      initializer = tf.constant(self.embedding_mat, dtype=tf.float32)
      shape_ = None
    else:
      initializer = tf.random_uniform_initializer(
          -self.params["embedding.init_scale"],
          self.params["embedding.init_scale"],
          dtype=tf.float32)
      shape_ = [self.source_vocab_info.total_size, self.params["embedding.dim"]]

    return tf.get_variable(name="W", shape=shape_, initializer=initializer,
                           trainable=self.params["embedding.tune"])

  @property
  @templatemethod("target_embedding")
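One detail in the source_embedding change above: when the initializer is a constant tensor, tf.get_variable must be called with shape=None so the variable infers its shape from the initializer. A self-contained sketch under that assumption (sizes illustrative):

import numpy as np
import tensorflow as tf

pretrained = np.random.uniform(-0.25, 0.25, size=(100, 10)).astype(np.float32)

# shape=None lets the variable take its shape from the constant initializer;
# trainable=False freezes the pre-trained embeddings during training.
emb = tf.get_variable("emb", shape=None,
                      initializer=tf.constant(pretrained),
                      trainable=False)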
14 changes: 10 additions & 4 deletions seq2seq/test/hooks_test.py
@@ -39,16 +39,22 @@ class TestPrintModelAnalysisHook(tf.test.TestCase):
  def test_begin(self):
    model_dir = tempfile.mkdtemp()
    outfile = tempfile.NamedTemporaryFile()
    tf.get_variable("weigths", [128, 128])
    tf.get_variable("weights", [128, 128])
    hook = hooks.PrintModelAnalysisHook(
        params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file:
      file_contents = file.read().strip()

    self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n"
                     " weigths (128x128, 16.38k/16.38k params)")
    lines = tf.compat.as_text(file_contents).split("\n")
    if len(lines) == 3:
      # TensorFlow v1.2 includes an extra header line
      self.assertEqual(lines[0], "node name | # parameters")

    self.assertEqual(lines[-2], "_TFProfRoot (--/16.38k params)")
    self.assertEqual(lines[-1], " weights (128x128, 16.38k/16.38k params)")

    outfile.close()


@@ -125,7 +131,7 @@ def tearDown(self):
  def test_capture(self):
    global_step = tf.contrib.framework.get_or_create_global_step()
    # Some test computation
    some_weights = tf.get_variable("weigths", [2, 128])
    some_weights = tf.get_variable("weights", [2, 128])
    computation = tf.nn.softmax(some_weights)

    hook = hooks.MetadataCaptureHook(
32 changes: 32 additions & 0 deletions seq2seq/test/models_test.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Models
"""

@@ -27,6 +28,7 @@
import tensorflow as tf

from seq2seq.data import vocab, input_pipeline
from seq2seq.data.embeddings import read_embeddings
from seq2seq.training import utils as training_utils
from seq2seq.test import utils as test_utils
from seq2seq.models import BasicSeq2Seq, AttentionSeq2Seq
@@ -227,6 +229,36 @@ def create_model(self, mode, params=None):
    params_.update(params or {})
    return AttentionSeq2Seq(params=params_, mode=mode)

class EmbeddingsFileTest(EncoderDecoderTests):
  """Tests for using pre-trained embeddings as source embeddings.
  """

  def setUp(self):
    super(EmbeddingsFileTest, self).setUp()
    self.emb_dim = 10
    self.embeddings_file = test_utils.create_temp_embedding_file(
        self.vocab_list, dim=self.emb_dim)

  def create_model(self, mode, params=None):
    params_ = BasicSeq2Seq.default_params().copy()
    params_.update(TEST_PARAMS)
    params_.update({
        "vocab_source": self.vocab_file.name,
        "vocab_target": self.vocab_file.name,
        "bridge.class": "PassThroughBridge",
        "embedding.file": self.embeddings_file.name,
        "embedding.tune": False
    })
    params_.update(params or {})
    return BasicSeq2Seq(params=params_, mode=mode)

  def test_read_embeddings(self):
    embeddings_mat = read_embeddings(self.embeddings_file.name,
                                     self.vocab_info.path)
    self.assertEqual(embeddings_mat.shape[1], self.emb_dim,
                     "Embeddings dimension should be %d but found %d"
                     % (self.emb_dim, embeddings_mat.shape[1]))


if __name__ == "__main__":
  tf.test.main()