144 changes: 87 additions & 57 deletions modules/hf_opt_module.py
@@ -10,6 +10,11 @@
from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding
from transformers.models.opt.configuration_opt import OPTConfig as GPTConfig

try:
    from flash_attn.flash_attention import FlashAttention
    flash_attn_installed = True
except ImportError:
    flash_attn_installed = False

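Note on the guarded import: flash_attn.flash_attention.FlashAttention is the flash-attn 1.x module API; on flash-attn 2.x that import path no longer exists, so the flag above silently falls back to the standard attention path even when a newer FlashAttention is installed. A minimal version-aware probe, sketched under that assumption (the flash_attn_version flag and the 2.x flash_attn_func import are illustrative, not part of this diff):

try:
    from flash_attn.flash_attention import FlashAttention  # flash-attn 1.x module API
    flash_attn_version = 1
except ImportError:
    try:
        from flash_attn import flash_attn_func  # flash-attn 2.x functional API
        flash_attn_version = 2
    except ImportError:
        flash_attn_version = None
flash_attn_installed = flash_attn_version is not None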
def _make_causal_mask(
input_ids_shape: torch.Size,
@@ -167,6 +172,9 @@ def __init__(
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)

        if flash_attn_installed:
            # query_states is already multiplied by self.scaling in forward(), so pass
            # softmax_scale=1.0 to keep the kernel from applying 1/sqrt(head_dim) a second time
            self.flash_attn = FlashAttention(softmax_scale=1.0, attention_dropout=dropout)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
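For context on what self.flash_attn consumes: the flash-attn 1.x FlashAttention module expects a packed tensor of shape (batch, seqlen, 3, num_heads, head_dim) in fp16/bf16 on a CUDA device, computes softmax(q @ k^T * scale) @ v with an optional causal mask, and returns the output together with (optional) attention weights. A pure-PyTorch sketch of that contract, with an illustrative helper name and the default 1/sqrt(head_dim) scaling (a reference for the shapes, not the fused kernel itself):

import math
import torch

def packed_causal_attention_reference(qkv: torch.Tensor, softmax_scale=None) -> torch.Tensor:
    # qkv: (batch, seqlen, 3, num_heads, head_dim), packed as in the forward pass below
    bsz, seqlen, _, num_heads, head_dim = qkv.shape
    q, k, v = qkv.unbind(dim=2)  # each (batch, seqlen, num_heads, head_dim)
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    scores = torch.einsum("bshd,bthd->bhst", q, k) * scale
    causal = torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    out = torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)
    return out.reshape(bsz, seqlen, num_heads * head_dim)  # (batch, seqlen, embed_dim)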
@@ -188,8 +196,10 @@ def forward(

bsz, tgt_len, _ = hidden_states.size()

        # take the FlashAttention path only for plain self-attention (no cross-attention,
        # no decoder-side key/value cache), matching the gate on the FlashAttention branch below
        use_flash_attn = not is_cross_attention and not self.is_decoder and flash_attn_installed

# get query proj
        query_states = self.q_proj(hidden_states) * self.scaling  # B S H
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
@@ -205,10 +215,14 @@
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
        elif not use_flash_attn:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        else:
            # self-attention via FlashAttention: keep key/value as (batch, seq, hidden)
            # so they can be packed into a single qkv tensor below
            key_states = self.k_proj(hidden_states)  # B S H
            value_states = self.v_proj(hidden_states)  # B S H

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -220,75 +234,91 @@
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

        if not is_cross_attention and not self.is_decoder and flash_attn_installed:
            # FlashAttention path: attention_mask, layer_head_mask and output_attentions
            # are not supported here; the kernel applies its own causal mask.
            qkv = torch.stack(
                [
                    query_states.view(bsz, tgt_len, self.num_heads, self.head_dim),
                    key_states.view(bsz, tgt_len, self.num_heads, self.head_dim),
                    value_states.view(bsz, tgt_len, self.num_heads, self.head_dim),
                ],
                dim=2,
            )

            out, _ = self.flash_attn(qkv, causal=True)  # assuming that these are autoregressive!!

            out = torch.reshape(out, (bsz, tgt_len, self.embed_dim))

            # project back to the embedding dimension, mirroring the standard path below
            attn_output = self.out_proj(out)

            return attn_output, None, None
        else:
            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
            key_states = key_states.view(*proj_shape)
            value_states = value_states.view(*proj_shape)

            src_len = key_states.size(1)
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
            dtype_attn_weights = attn_weights.dtype

            # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
            if dtype_attn_weights == torch.float16:
                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
            else:
                attn_weights = nn.functional.softmax(attn_weights, dim=-1)

            if layer_head_mask is not None:
                if layer_head_mask.size() != (self.num_heads,):
                    raise ValueError(
                        f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                        f" {layer_head_mask.size()}"
                    )
                attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            if output_attentions:
                # this operation is a bit awkward, but it's required to
                # make sure that attn_weights keeps its gradient.
                # In order to do so, attn_weights have to be reshaped
                # twice and have to be reused in the following
                attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
            else:
                attn_weights_reshaped = None

            attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

            attn_output = torch.bmm(attn_probs, value_states)

            if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
            attn_output = attn_output.transpose(1, 2)

            # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
            # partitioned across GPUs when using tensor-parallelism.
            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

            attn_output = self.out_proj(attn_output)

            return attn_output, attn_weights_reshaped, past_key_value
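A hedged smoke test for the FlashAttention branch (not part of this diff, reusing the FlashAttention import and flash_attn_installed flag from the top of the file): pack random half-precision q/k/v the same way forward() does, run them through the flash-attn 1.x module, and compare against PyTorch's scaled_dot_product_attention as a reference. Assumes flash-attn 1.x, a CUDA device, and PyTorch >= 2.0; tolerances are loose because inputs and outputs are half precision:

import torch

if flash_attn_installed and torch.cuda.is_available():
    bsz, seqlen, num_heads, head_dim = 2, 128, 12, 64
    qkv = torch.randn(bsz, seqlen, 3, num_heads, head_dim, device="cuda", dtype=torch.float16)

    flash = FlashAttention(attention_dropout=0.0)  # default 1/sqrt(head_dim) scaling
    out_flash, _ = flash(qkv, causal=True)         # (bsz, seqlen, num_heads, head_dim)
    out_flash = out_flash.reshape(bsz, seqlen, num_heads * head_dim)

    # reference in (bsz, num_heads, seqlen, head_dim) layout, fp32, causal mask
    q, k, v = (t.transpose(1, 2).float() for t in qkv.unbind(dim=2))
    out_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    out_ref = out_ref.transpose(1, 2).reshape(bsz, seqlen, num_heads * head_dim)

    assert torch.allclose(out_flash.float(), out_ref, atol=2e-2, rtol=2e-2)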


class GPTBlock(OPTDecoderLayer):