
What is the variable "is_recompute" designed for? #16

@Ga-Lee


Hi, thank you for your contribution to the community. I have learned a lot from your wonderful work!

After reading your code below

is_recompute = current_end <= kv_cache["global_end_index"].item() and current_start > 0

I would like to ask you some questions about the KV cache update in causal_model.py. It looks a little similar to Self-Forcing, but the design is actually quite different.

In particular, I would like to know what the variable is_recompute is designed for.
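Here is my own guess at what that condition does, as a pure-Python sketch (the dict layout and helper name are mine, not your actual code): during the first forward pass the new chunk extends past global_end_index, so the condition is False; once the cache index has been advanced, a checkpoint recomputation of the same chunk satisfies current_end <= global_end_index and can be detected.

```python
def is_recompute_pass(kv_cache, current_start, current_end):
    # Hypothetical helper mirroring the quoted line: a chunk whose end
    # does not extend past the cache's recorded end index, and which is
    # not the very first chunk, must be a replayed (recomputed) forward.
    return current_end <= kv_cache["global_end_index"] and current_start > 0

# Simulated timeline: cache already holds 16 tokens, we process tokens 16..32.
cache = {"global_end_index": 16}
first = is_recompute_pass(cache, 16, 32)    # original forward
cache["global_end_index"] = 32              # index advanced after the forward
second = is_recompute_pass(cache, 16, 32)   # checkpoint recomputation
print(first, second)  # → False True
```

Is that roughly the intended reading, or does the flag serve another purpose?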

I know torch.utils.checkpoint.checkpoint is required for training, but directly applying it seems to cause a tensor mismatch between the original forward pass and the activation recomputation pass. For example, when I tried to update the KV cache inside the self-attention module without using cache_update_info, I got errors like torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
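My understanding of why deferring the write helps, sketched in plain Python (this is my assumption about the pattern, not your actual implementation; attention_forward, apply_cache_updates, and the tuple format are made up for illustration): if the checkpointed forward only records a pending write and the trainer applies it once outside the checkpointed region, the recomputation pass sees the cache in exactly the state it had during the original forward, so no metadata can drift.

```python
def attention_forward(x_len, cache, cache_update_info, is_recompute):
    # Inside the checkpointed region: read-only with respect to the cache.
    start = cache["global_end_index"]
    if not is_recompute:
        # Record the pending write instead of mutating the cache here.
        cache_update_info.append((start, start + x_len))
    return x_len  # stand-in for the attention output

def apply_cache_updates(cache, cache_update_info):
    # Outside the checkpointed region: mutate the cache exactly once.
    for start, end in cache_update_info:
        cache["global_end_index"] = end
    cache_update_info.clear()

cache = {"global_end_index": 0}
pending = []
attention_forward(16, cache, pending, is_recompute=False)  # original forward
attention_forward(16, cache, pending, is_recompute=True)   # checkpoint replay sees same state
apply_cache_updates(cache, pending)
print(cache["global_end_index"])  # → 16
```

Is this close to what cache_update_info is doing, or am I missing something?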

Could you please briefly explain how your design solves (1) the KV cache over-writing conflict across timesteps and (2) the potential tensor shape mismatch during gradient-checkpoint activation recomputation? Your answer would mean a lot to me!

Thank you very much for your kind reply!
