diff --git a/examples/11_char_rnn_gist.py b/examples/11_char_rnn_gist.py index c51c099f..8e15e66a 100644 --- a/examples/11_char_rnn_gist.py +++ b/examples/11_char_rnn_gist.py @@ -59,7 +59,7 @@ def create_rnn(seq, hidden_size=HIDDEN_SIZE): return output, in_state, out_state def create_model(seq, temp, vocab, hidden=HIDDEN_SIZE): - seq = tf.one_hot(seq, len(vocab)) + seq = tf.one_hot(seq - 1, len(vocab)) output, in_state, out_state = create_rnn(seq, hidden) # fully_connected is syntactic sugar for tf.matmul(w, output) + b # it will create w and b for us @@ -103,7 +103,7 @@ def online_inference(sess, vocab, seq, sample, temp, in_state, out_state, seed=' if state is not None: feed.update({in_state: state}) index, state = sess.run([sample, out_state], feed) - sentence += vocab_decode(index, vocab) + sentence += vocab_decode(index + 1, vocab) print(sentence) def main(): @@ -120,4 +120,4 @@ def main(): training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, out_state) if __name__ == '__main__': - main() \ No newline at end of file + main()