Hi, thank you for your contribution to the community. I have learned a lot from your wonderful work!
After reading this line in `LongLive/wan/modules/causal_model.py` (line 229 at commit a4a0c87):

```python
is_recompute = current_end <= kv_cache["global_end_index"].item() and current_start > 0
```
I would like to ask some questions about the KV cache update in `causal_model.py`. It looks a little similar to Self-Forcing, but the design is actually quite different. In particular, what is the variable `is_recompute` designed for?
I know `torch.utils.checkpoint.checkpoint` is required for training, but directly applying it seems to cause a tensor mismatch between the original forward pass and the activation recomputation pass. For example, when I tried to update the KV cache inside the self-attention module without using `cache_update_info`, I got errors like:

```
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
```
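For context on what I think is going wrong, here is a minimal torch-free sketch of the failure mode (all names here are hypothetical, not from your code): checkpointing discards activations during the forward pass and re-runs the function during backward, so if the function also mutates a shared KV cache in place, the re-run sees a longer cache and produces tensors with different shapes.

```python
# Conceptual sketch (no torch): gradient checkpointing re-runs the
# forward function during backward. If that function mutates a shared
# KV cache, the re-run sees different state and produces outputs of a
# different shape -- analogous to torch's CheckpointError above.

def attend(query, kv_cache):
    """Toy attention: appends the query to the cache, then returns one
    score per cached position. Mutates kv_cache in place (the bug)."""
    kv_cache.append(query)                # side effect: cache grows
    return [query * k for k in kv_cache]  # length == current cache size

def checkpointed_backward(fn, query, kv_cache):
    """Mimic activation recomputation: run fn once 'for real', then
    re-run it for backward and compare output metadata (here, length)."""
    forward_out = fn(query, kv_cache)
    recomputed = fn(query, kv_cache)      # cache grew again -> mismatch
    if len(recomputed) != len(forward_out):
        raise RuntimeError(
            f"Recomputed values have different metadata: "
            f"len {len(recomputed)} vs {len(forward_out)}")
    return recomputed

cache = []
try:
    checkpointed_backward(attend, 2.0, cache)
except RuntimeError as err:
    print("mismatch reproduced:", err)
```

If I understand correctly, this is why the update has to be deferred (e.g. via something like `cache_update_info`) instead of being applied inside the checkpointed forward, so that both the forward pass and the recomputation see an identical cache.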
Could you please briefly explain how your design solves (1) the KV cache overwriting conflict across timesteps, and (2) the potential tensor shape mismatch during gradient-checkpoint activation recomputation? Your answer would mean a lot to me!
Thank you very much for your kind reply!