From 331daa8b7300493a18758cfe556fbe2b51f5079c Mon Sep 17 00:00:00 2001
From: Maxim Podkolzine <smartmaxim@gmail.com>
Date: Sat, 16 Dec 2017 13:30:49 +0100
Subject: [PATCH 1/2] Fix one-hot encoding

TL;DR: the length calculation is wrong; padded zeros are never ignored.

Note that `vocab_encode` encodes each char as an index in `1`..`vocab_len`: that's what is stored in `seq` before it goes through one-hot encoding. It is expected that `tf.one_hot` will encode only valid indices and return zeros for paddings (which are `0`), but that's not what it does. Instead, it will encode every index in `0`..`vocab_len-1` and ignore `vocab_len`. This means that the `}` char will always end the seq, while padded zeros are processed as normal chars.

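Doing `seq - 1` fixes both indices: the padding `0` becomes `-1` (out of range, so `tf.one_hot` encodes it as all zeros, as it should be invalid) and `vocab_len` becomes `vocab_len-1` (in range, as it should be valid).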
---
 examples/11_char_rnn_gist.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/11_char_rnn_gist.py b/examples/11_char_rnn_gist.py
index c51c099f..3d6bbc34 100644
--- a/examples/11_char_rnn_gist.py
+++ b/examples/11_char_rnn_gist.py
@@ -59,7 +59,7 @@ def create_rnn(seq, hidden_size=HIDDEN_SIZE):
     return output, in_state, out_state
 
 def create_model(seq, temp, vocab, hidden=HIDDEN_SIZE):
-    seq = tf.one_hot(seq, len(vocab))
+    seq = tf.one_hot(seq - 1, len(vocab))
     output, in_state, out_state = create_rnn(seq, hidden)
     # fully_connected is syntactic sugar for tf.matmul(w, output) + b
     # it will create w and b for us
@@ -120,4 +120,4 @@ def main():
     training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, out_state)
     
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()

From 430d756f1838b423f60c104f55664e3df17690fe Mon Sep 17 00:00:00 2001
From: Maxim Podkolzine <smartmaxim@gmail.com>
Date: Sat, 16 Dec 2017 14:47:34 +0100
Subject: [PATCH 2/2] fix the sample generator as well

Since one-hot encoding shifts the index down by 1, the generator must account for that; otherwise, the sample sequence will collapse to zeros.
---
 examples/11_char_rnn_gist.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/11_char_rnn_gist.py b/examples/11_char_rnn_gist.py
index 3d6bbc34..8e15e66a 100644
--- a/examples/11_char_rnn_gist.py
+++ b/examples/11_char_rnn_gist.py
@@ -103,7 +103,7 @@ def online_inference(sess, vocab, seq, sample, temp, in_state, out_state, seed='
         if state is not None:
             feed.update({in_state: state})
         index, state = sess.run([sample, out_state], feed)
-        sentence += vocab_decode(index, vocab)
+        sentence += vocab_decode(index + 1, vocab)
     print(sentence)
 
 def main():