From 331daa8b7300493a18758cfe556fbe2b51f5079c Mon Sep 17 00:00:00 2001 From: Maxim Podkolzine <smartmaxim@gmail.com> Date: Sat, 16 Dec 2017 13:30:49 +0100 Subject: [PATCH 1/2] Fix one-hot encoding LT;DR: length calculation is wrong, padded zeros are never ignored. Note that `vocab_encode` encodes the each char an index in `1`..`vocab_len`: that's what is stored in `seq` before it goes through one-hot encodding. It is expected that `tf.one_hot` will encode only valid indices and return zeros for paddings (which is `0`), but it's not what it does. Instead, it will encode every index in `0`..`vocab_len-1` and ignore `vocab_len`. This means that `}` char will always end the seq, while padded zeros are processed as normal chars. Doing `seq - 1` fixes both the padding `0` (should be invalid) and `vocab_len` (should be valid) indices. --- examples/11_char_rnn_gist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/11_char_rnn_gist.py b/examples/11_char_rnn_gist.py index c51c099f..3d6bbc34 100644 --- a/examples/11_char_rnn_gist.py +++ b/examples/11_char_rnn_gist.py @@ -59,7 +59,7 @@ def create_rnn(seq, hidden_size=HIDDEN_SIZE): return output, in_state, out_state def create_model(seq, temp, vocab, hidden=HIDDEN_SIZE): - seq = tf.one_hot(seq, len(vocab)) + seq = tf.one_hot(seq - 1, len(vocab)) output, in_state, out_state = create_rnn(seq, hidden) # fully_connected is syntactic sugar for tf.matmul(w, output) + b # it will create w and b for us @@ -120,4 +120,4 @@ def main(): training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, out_state) if __name__ == '__main__': - main() \ No newline at end of file + main() From 430d756f1838b423f60c104f55664e3df17690fe Mon Sep 17 00:00:00 2001 From: Maxim Podkolzine <smartmaxim@gmail.com> Date: Sat, 16 Dec 2017 14:47:34 +0100 Subject: [PATCH 2/2] fix the sample generator as well Since one-hot encoding shifts the index down by 1, the generator must account for that, otherwise the sample sequence will collapse to zeros --- examples/11_char_rnn_gist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/11_char_rnn_gist.py b/examples/11_char_rnn_gist.py index 3d6bbc34..8e15e66a 100644 --- a/examples/11_char_rnn_gist.py +++ b/examples/11_char_rnn_gist.py @@ -103,7 +103,7 @@ def online_inference(sess, vocab, seq, sample, temp, in_state, out_state, seed=' if state is not None: feed.update({in_state: state}) index, state = sess.run([sample, out_state], feed) - sentence += vocab_decode(index, vocab) + sentence += vocab_decode(index + 1, vocab) print(sentence) def main():