diff --git a/modules/hf_opt_module.py b/modules/hf_opt_module.py
index 80819e55..b6f7a989 100644
--- a/modules/hf_opt_module.py
+++ b/modules/hf_opt_module.py
@@ -10,6 +10,11 @@
 from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding
 from transformers.models.opt.configuration_opt import OPTConfig as GPTConfig
+try:
+    from flash_attn.flash_attention import FlashAttention
+    flash_attn_installed = True
+except ImportError:
+    flash_attn_installed = False
 
 def _make_causal_mask(
     input_ids_shape: torch.Size,
@@ -167,6 +172,9 @@ def __init__(
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
+
+        if flash_attn_installed:
+            # query_states is already multiplied by self.scaling in forward(), so disable
+            # FlashAttention's internal 1/sqrt(head_dim) scaling to avoid scaling twice.
+            self.flash_attn = FlashAttention(softmax_scale=1.0, attention_dropout=dropout)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -188,8 +196,10 @@ def forward(
 
         bsz, tgt_len, _ = hidden_states.size()
 
+        # FlashAttention only covers plain self-attention with no cached key/value states;
+        # every other case falls back to the reference implementation below.
+        use_flash_attn = flash_attn_installed and not is_cross_attention and not self.is_decoder and past_key_value is None
+
         # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
+        query_states = self.q_proj(hidden_states) * self.scaling  # B S H
         # get key, value proj
         if is_cross_attention and past_key_value is not None:
             # reuse k,v, cross_attentions
@@ -205,10 +215,14 @@
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
+        elif not use_flash_attn:
            # self_attention
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        else:
+            # self-attention with FlashAttention
+            key_states = self.k_proj(hidden_states)  # B S H
+            value_states = self.v_proj(hidden_states)  # B S H
 
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -220,75 +234,91 @@
             # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_states, value_states)
 
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        if use_flash_attn:
+            qkv = torch.stack(
+                [
+                    query_states.view(bsz, tgt_len, self.num_heads, self.head_dim),
+                    key_states.view(bsz, tgt_len, self.num_heads, self.head_dim),
+                    value_states.view(bsz, tgt_len, self.num_heads, self.head_dim),
+                ],
+                dim=2
+            )
 
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+            out, _ = self.flash_attn(qkv, causal=True)  # assumes autoregressive (causal) self-attention
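+            # qkv is packed as (bsz, tgt_len, 3, num_heads, head_dim), the layout FlashAttention expects,
+            # and `out` comes back as (bsz, tgt_len, num_heads, head_dim). The kernel applies the causal
+            # mask and attention dropout internally; `attention_mask`, `layer_head_mask` and
+            # `output_attentions` are not supported on this path.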
-        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+            out = torch.reshape(out, (bsz, tgt_len, self.embed_dim))
+            attn_output = self.out_proj(out)  # output projection, as in the reference path below
 
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        dtype_attn_weights = attn_weights.dtype
-
-        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
-        if dtype_attn_weights == torch.float16:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+            return attn_output, None, None
         else:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.view(*proj_shape)
+            value_states = value_states.view(*proj_shape)
 
-        if layer_head_mask is not None:
-            if layer_head_mask.size() != (self.num_heads,):
+            src_len = key_states.size(1)
+            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+            if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
                 raise ValueError(
-                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                    f" {layer_head_mask.size()}"
+                    f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                    f" {attn_weights.size()}"
                 )
-            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
-        else:
-            attn_weights_reshaped = None
-
-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            dtype_attn_weights = attn_weights.dtype
+
+            # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
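+            # (softmax in fp16 can lose precision, so it is computed in fp32 here and cast back to the original dtype)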
+            if dtype_attn_weights == torch.float16:
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+            else:
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+            if layer_head_mask is not None:
+                if layer_head_mask.size() != (self.num_heads,):
+                    raise ValueError(
+                        f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                        f" {layer_head_mask.size()}"
+                    )
+                attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+            if output_attentions:
+                # this operation is a bit awkward, but it's required to
+                # make sure that attn_weights keeps its gradient.
+                # In order to do so, attn_weights have to be reshaped
+                # twice and have to be reused in the following
+                attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+                attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+            else:
+                attn_weights_reshaped = None
 
-        attn_output = torch.bmm(attn_probs, value_states)
+            attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
+            attn_output = torch.bmm(attn_probs, value_states)
+
+            if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
 
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+            attn_output = attn_output.transpose(1, 2)
 
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned aross GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+            # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+            # partitioned across GPUs when using tensor-parallelism.
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
 
-        attn_output = self.out_proj(attn_output)
+            attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights_reshaped, past_key_value
+            return attn_output, attn_weights_reshaped, past_key_value
 
 class GPTBlock(OPTDecoderLayer):