Currently, `decode` is hardcoded to `False` in `train_dynamics.py`. We should instead add an argument to conditionally enable KV caching. We should also change the naming. There is a separate open issue for this: https://github.com/p-doom/jasmine/issues/191
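
A minimal sketch of what such an argument could look like. This is an assumption, not the actual code in `train_dynamics.py`: the flag name `use_kv_cache`, the `Args` dataclass, and the `model_kwargs` helper are all hypothetical, and plain `argparse` is used only to keep the example self-contained.

```python
import argparse
from dataclasses import dataclass


@dataclass
class Args:
    # Hypothetical flag name; the final naming is still to be decided.
    use_kv_cache: bool = False


def parse_args(argv=None) -> Args:
    """Parse command-line flags into Args (argparse is a stand-in here)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-kv-cache",
        action="store_true",
        help="Enable KV caching (passes decode=True to the model).",
    )
    ns = parser.parse_args(argv)
    return Args(use_kv_cache=ns.use_kv_cache)


def model_kwargs(args: Args) -> dict:
    """Hypothetical helper: derive `decode` from the flag instead of hardcoding False."""
    return {"decode": args.use_kv_cache}


if __name__ == "__main__":
    print(model_kwargs(parse_args()))
```

Defaulting the flag to `False` keeps the current behaviour unless KV caching is explicitly requested.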