 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import OrderedDict
 from typing import Unpack

 import torch
 import torch.nn as nn
 import transformer_engine.pytorch
 import transformers
+from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
 from transformers import LlamaConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_rope_utils import dynamic_rope_update
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.utils.generic import TransformersKwargs


-class NVLlamaConfig(LlamaConfig): ...  # noqa: D101
+class NVLlamaConfig(LlamaConfig):
+    """NVLlama configuration."""
+
+    attn_input_format: str = "bshd"
+    self_attn_mask_type: str = "padding_causal"


 class NVLlamaPreTrainedModel(PreTrainedModel):
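The two class attributes added above select Transformer Engine's sequence layout ("bshd" batches vs. "thd" packed sequences) and self-attention mask type for every layer. A small usage sketch (hypothetical values; since NVLlamaConfig inherits from LlamaConfig, extra keyword arguments are stored as attributes and override the class-level defaults):

# Hypothetical override of the TE-specific defaults when building the config.
config = NVLlamaConfig(
    num_hidden_layers=2,
    attn_input_format="thd",                  # packed sequences instead of the default "bshd"
    self_attn_mask_type="padding_causal",
)
assert config.attn_input_format == "thd"
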
@@ -62,7 +68,8 @@ def __init__(self, config: LlamaConfig):
                     qkv_weight_interleaved=True,
                     normalization="RMSNorm",
                     activation="swiglu",
-                    attn_input_format="bshd",
+                    attn_input_format=config.attn_input_format,
+                    self_attn_mask_type=config.self_attn_mask_type,
                     num_gqa_groups=config.num_key_value_heads,
                     layer_number=layer_idx + 1,
                     params_dtype=config.dtype,
@@ -71,7 +78,12 @@ def __init__(self, config: LlamaConfig):
             ]
         )
         self.norm = transformer_engine.pytorch.RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=config.dtype)
-        self.rotary_emb = NVLlamaRotaryEmbedding(config=config)
+
+        # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
+        # LlamaRotaryEmbedding.
+        self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
+        self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
+
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
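A minimal sketch of the RoPE swap above (hypothetical snippet, not part of the commit; assumes a CUDA-capable build of Transformer Engine, since TE allocates its RoPE buffers on the GPU):

# Copying HF's inv_freq into TE's RotaryPositionEmbedding keeps the frequency schedule
# (rope_theta and any scaling baked into inv_freq) identical to the reference Llama implementation.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding

cfg = LlamaConfig(hidden_size=256, num_attention_heads=4, num_hidden_layers=1)
head_dim = cfg.hidden_size // cfg.num_attention_heads
te_rope = RotaryPositionEmbedding(head_dim)
te_rope.inv_freq = LlamaRotaryEmbedding(config=cfg).inv_freq  # reuse HF's frequencies
freqs = te_rope(max_seq_len=8)  # raw frequencies for TE's fused RoPE, expected shape [8, 1, 1, head_dim]
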
@@ -82,9 +94,8 @@ def forward(
         input_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
-        past_key_values: tuple[tuple[torch.Tensor, ...], ...] | None = None,
+        past_key_values: InferenceParams | None = None,
         inputs_embeds: torch.Tensor | None = None,
-        cache_position: torch.Tensor | None = None,
         use_cache: bool | None = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
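With the signature change above, the cache argument is TE's InferenceParams rather than an HF Cache object. A hedged construction sketch (the exact constructor keywords vary between Transformer Engine releases, so treat them as assumptions; this forward pass only relies on .max_sequence_length and .pre_step()):

# Assumed constructor arguments; only max_sequence_length and pre_step() are used by this model.
past_key_values = InferenceParams(max_batch_size=2, max_sequence_length=512)
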
@@ -96,7 +107,6 @@ def forward(
             position_ids (torch.Tensor): The position ids.
             past_key_values (tuple[tuple[torch.Tensor, ...], ...]): The past key values.
             inputs_embeds (torch.Tensor): The inputs embeds.
-            cache_position (torch.Tensor): The cache position.
             use_cache (bool): Whether to use cache.
             **kwargs: Additional keyword arguments.

@@ -112,34 +122,64 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None:
-            past_key_values = transformers.cache_utils.DynamicCache(config=self.config)
+        hidden_states = inputs_embeds
+        if self.config.attn_input_format == "bshd":
+            if past_key_values is not None:
+                max_seq_len = past_key_values.max_sequence_length
+            else:
+                max_seq_len = hidden_states.shape[1]
+            te_rope_emb = self.rotary_emb(max_seq_len=max_seq_len)
+        elif self.config.attn_input_format == "thd":
+            te_rope_emb = self.rotary_emb(max_seq_len=kwargs["cu_seq_lens_q"][-1])
+
+        has_thd_input = [
+            x is not None
+            for x in [
+                kwargs.get("cu_seq_lens_q", None),
+                kwargs.get("cu_seq_lens_k", None),
+                kwargs.get("max_length_q", None),
+                kwargs.get("max_length_k", None),
+            ]
+        ]

-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position: torch.Tensor = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        if isinstance(past_key_values, InferenceParams):
+            # lengths = attention_mask.sum(dim=1) if attention_mask is not None else torch.tensor([0])
+            lengths = input_ids.ne(0).sum(dim=1) if input_ids is not None else torch.tensor([0])
+            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
+
+        if self.config.attn_input_format == "thd":
+            if not all(has_thd_input):
+                raise ValueError(
+                    "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
+                )
+            assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
+                "THD expects embeddings shaped [1, total_tokens, hidden_size]."
             )
+            hidden_states = hidden_states.squeeze(0)
+            attention_mask = None

-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+        elif self.config.attn_input_format == "bshd" and any(has_thd_input):
+            raise ValueError(
+                "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
+            )

-        hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        # Construct the appropriate attention mask.
+        if attention_mask is not None and self.config.self_attn_mask_type == "padding_causal":
+            attention_mask = ~attention_mask.to(bool)[:, None, None, :]

         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states = (*all_hidden_states, hidden_states)

             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=None,
-                self_attn_mask_type="causal",
-                rotary_pos_emb=position_embeddings,
-                # position_ids=position_ids,
-                # past_key_values=past_key_values,
-                # cache_position=cache_position,
-                # **kwargs,
+                attention_mask=attention_mask,
+                rotary_pos_emb=te_rope_emb,
+                inference_params=past_key_values,
+                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
+                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+                max_seqlen_q=kwargs.get("max_length_q", None),
+                max_seqlen_kv=kwargs.get("max_length_k", None),
             )

         hidden_states = self.norm(hidden_states)
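A sketch of what a packed (THD) call into this forward could look like, following the kwarg names validated above (hypothetical tensors; `model` stands for an NVLlamaModel built with attn_input_format="thd"):

import torch

# Two sequences of lengths 4 and 3, packed into a single [1, total_tokens] batch.
seqs = [torch.tensor([11, 12, 13, 14]), torch.tensor([21, 22, 23])]
input_ids = torch.cat(seqs).unsqueeze(0)  # shape: [1, 7]
seq_lens = torch.tensor([s.numel() for s in seqs], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)  # [0, 4, 7]

outputs = model(
    input_ids=input_ids,
    cu_seq_lens_q=cu_seqlens,
    cu_seq_lens_k=cu_seqlens,
    max_length_q=int(seq_lens.max()),
    max_length_k=int(seq_lens.max()),
)
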
@@ -185,7 +225,7 @@ def forward(
         labels: torch.Tensor | None = None,
         use_cache: bool | None = None,
         cache_position: torch.Tensor | None = None,
-        logits_to_keep: int | torch.Tensor = 0,
+        only_keep_last_logits: bool = False,
         **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         """Forward pass for the NVLlamaForCausalLM model.
@@ -199,7 +239,8 @@ def forward(
             labels (torch.Tensor): The labels.
             use_cache (bool): Whether to use cache.
             cache_position (torch.Tensor): The cache position.
-            logits_to_keep (int | torch.Tensor): The logits to keep.
+            only_keep_last_logits (bool): Whether to keep only the last logits, as a workaround for the fact that TE
+                doesn't support left-side padding with `padding_causal` attention masks.
             **kwargs: Additional keyword arguments.

         Returns:
@@ -217,9 +258,26 @@ def forward(
         )

         hidden_states = outputs.last_hidden_state
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        # TE doesn't support left-side padding with `padding_causal` attention masks, and InferenceParams doesn't
+        # support arbitrary attention masks (and the attention backend for arbitrary masks is the slower, unfused
+        # backend). To keep the inference interface consistent with HF's `GenerationMixin.generate` interface, we use
+        # an `only_keep_last_logits` flag to indicate that we should pick out and return only the last token's hidden
+        # state during pre-fill. This allows generation to work with right-side padding. Note: make sure that you
+        # decode the tokens with `skip_special_tokens=True` when using this flag, otherwise padding tokens will
+        # interrupt the generated text.
+        if (
+            only_keep_last_logits
+            and attention_mask is not None  # Padded inputs
+            and hidden_states.shape[1] > 1  # We're in pre-fill mode
+        ):
+            seqlens = attention_mask.sum(dim=1)  # shape: [batch]
+            # For each batch idx, select hidden_states[idx, seqlens[idx] - 1, :]
+            batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
+            selected_hidden_states = hidden_states[batch_indices, seqlens - 1, :]  # shape: [batch, hidden_dim]
+            hidden_states = selected_hidden_states.unsqueeze(1)  # shape: [batch, 1, hidden_dim]
+
+        logits = self.lm_head(hidden_states)

         loss = None
         if labels is not None:
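A sketch of the generation pattern the comment above describes (hypothetical `model` and `tokenizer` objects; assumes the flag is forwarded to forward() as a model kwarg by `GenerationMixin.generate`):

# Right-side padding plus only_keep_last_logits=True during pre-fill, then decode skipping pad tokens.
tokenizer.padding_side = "right"
batch = tokenizer(["Hello world", "A much longer prompt here"], return_tensors="pt", padding=True)

generated = model.generate(
    **batch,
    max_new_tokens=16,
    only_keep_last_logits=True,  # pre-fill returns only each sequence's last valid hidden state
)
text = tokenizer.batch_decode(generated, skip_special_tokens=True)  # drop pad tokens from the output text
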
@@ -248,25 +306,3 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio
 class NVLlamaForTokenClassification(  # noqa: D101
     transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel
 ): ...
-
-
-class NVLlamaRotaryEmbedding(LlamaRotaryEmbedding):
-    """Slight modification of the LlamaRotaryEmbedding for use with Transformer Engine."""
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):  # pyright: ignore[reportIncompatibleMethodOverride]
-        """Forward pass for the NVLlamaRotaryEmbedding.
-
-        Unlike the original LlamaRotaryEmbedding, this implementation returns the frequency tensor (upstream of the
-        cosine and sine transforms), reshaped in a way that is compatible with TransformerEngine's fused RoPE.
-        """
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-
-        return emb.transpose(0, 1).unsqueeze(1)