import transformer_engine.pytorch
import transformers
from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers import LlamaConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

class NVLlamaConfig(LlamaConfig):
    """NVLlama configuration."""

-    attn_input_format: str = "bshd"
-    self_attn_mask_type: str = "padding_causal"
-

class NVLlamaPreTrainedModel(PreTrainedModel):
    """Base class for NVLlama models."""
@@ -68,8 +66,8 @@ def __init__(self, config: LlamaConfig):
                qkv_weight_interleaved=True,
                normalization="RMSNorm",
                activation="swiglu",
-                attn_input_format=config.attn_input_format,
-                self_attn_mask_type=config.self_attn_mask_type,
+                attn_input_format="thd",
+                self_attn_mask_type="padding_causal",
                num_gqa_groups=config.num_key_value_heads,
                layer_number=layer_idx + 1,
                params_dtype=config.dtype,
@@ -123,71 +121,76 @@ def forward(
        inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
-        if self.config.attn_input_format == "bshd":
-            if past_key_values is not None:
-                max_seq_len = past_key_values.max_sequence_length
-            else:
-                max_seq_len = hidden_states.shape[1]
-            te_rope_emb = self.rotary_emb(max_seq_len=max_seq_len)
-        elif self.config.attn_input_format == "thd":
-            te_rope_emb = self.rotary_emb(max_seq_len=kwargs["cu_seq_lens_q"][-1])
-
-        has_thd_input = [
-            x is not None
-            for x in [
-                kwargs.get("cu_seq_lens_q", None),
-                kwargs.get("cu_seq_lens_k", None),
-                kwargs.get("max_length_q", None),
-                kwargs.get("max_length_k", None),
-            ]
-        ]

-        if isinstance(past_key_values, InferenceParams):
-            # lengths = attention_mask.sum(dim=1) if attention_mask is not None else torch.tensor([0])
-            lengths = input_ids.ne(0).sum(dim=1) if input_ids is not None else torch.tensor([0])
-            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
-
-        if self.config.attn_input_format == "thd":
-            if not all(has_thd_input):
-                raise ValueError(
-                    "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
-                )
+        has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]]
+        should_pack_inputs = not any(has_thd_input)
+
+        # Packing may be slower than plain BSHD + padding with the fused attention backend, but it should be faster
+        # with the flash attention backend.
+        if should_pack_inputs:
+            # Left-side padding is not supported in TE layers, so to make generation work with TE we dynamically
+            # convert to THD-style inputs in our forward pass, and then convert back to BSHD for the output. This
+            # lets the entire transformer stack run in THD mode.
+            assert attention_mask is not None, "Attention mask is required when using BSHD inputs."
+            batch_size = hidden_states.size(0)
+            hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
+            cu_seq_lens_q = cu_seq_lens_k = cu_seqlens
+            max_length_q = max_length_k = max_seqlen
+
+        else:
+            # Here, we're providing THD-style inputs, so we can just grab the kwargs.
            assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
                "THD expects embeddings shaped [1, total_tokens, hidden_size]."
            )
            hidden_states = hidden_states.squeeze(0)
-            attention_mask = None
+            cu_seq_lens_q = kwargs["cu_seq_lens_q"]
+            cu_seq_lens_k = kwargs["cu_seq_lens_k"]
+            max_length_q = kwargs["max_length_q"]
+            max_length_k = kwargs["max_length_k"]
+
+        # If we're using kv-caching, we can't trust the max_length_q value as the true max length for rotary
+        # embeddings, since this will be 1 in generation. Instead we can take the max sequence length from the past
+        # key values object.
+        te_rope_emb = self.rotary_emb(
+            max_seq_len=max_length_q if past_key_values is None else past_key_values.max_ctx_len
+        )

-        elif self.config.attn_input_format == "bshd" and any(has_thd_input):
-            raise ValueError(
-                "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
+        if isinstance(past_key_values, InferenceParams):
+            # In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
+            # compute the lengths of each sequence in the batch.
+            lengths = (
+                attention_mask.sum(dim=1).tolist()
+                if attention_mask.shape == input_ids.shape
+                else [1] * input_ids.shape[0]
            )
-
-        # Construct the appropriate attention mask.
-        if attention_mask is not None and self.config.self_attn_mask_type == "padding_causal":
-            attention_mask = ~attention_mask.to(bool)[:, None, None, :]
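+            # pre_step() tells TE's KV-cache manager how many tokens each sequence adds in this forward pass.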
+            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states = (*all_hidden_states, hidden_states)

            hidden_states = decoder_layer(
                hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=None,
                rotary_pos_emb=te_rope_emb,
                inference_params=past_key_values,
-                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
-                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
-                max_seqlen_q=kwargs.get("max_length_q", None),
-                max_seqlen_kv=kwargs.get("max_length_k", None),
+                cu_seqlens_q=cu_seq_lens_q,
+                cu_seqlens_kv=cu_seq_lens_k,
+                max_seqlen_q=max_length_q,
+                max_seqlen_kv=max_length_k,
            )

        hidden_states = self.norm(hidden_states)

-        # add hidden states from the last decoder layer
+        # add hidden states from the last decoder layer. Note that these will be in THD format; we could possibly pad
+        # these with the same _pad_input call as below if we wanted them returned in BSHD format.
        if output_hidden_states:
            all_hidden_states = (*all_hidden_states, hidden_states)

+        if should_pack_inputs:
+            # If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
+            hidden_states = _pad_input(hidden_states, indices, batch_size, max_length_q)
+
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
@@ -225,7 +228,7 @@ def forward(
        labels: torch.Tensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
-        only_keep_last_logits: bool = False,
+        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Forward pass for the NVLlamaForCausalLM model.
@@ -239,8 +242,8 @@ def forward(
            labels (torch.Tensor): The labels.
            use_cache (bool): Whether to use cache.
            cache_position (torch.Tensor): The cache position.
-            only_keep_last_logits (bool): Whether to keep only the last logits, as a workaround for the fact that TE
-                doesn't support left-side padding with `padding_causal` attention masks.
+            logits_to_keep (int | torch.Tensor): If an int, compute logits only for the last `logits_to_keep` tokens
+                (0 keeps all of them); if a tensor, the token indices to keep. Reduces the memory footprint during
+                generation.
            **kwargs: Additional keyword arguments.

        Returns:
@@ -258,26 +261,13 @@ def forward(
        )

        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
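+        # With the default logits_to_keep=0, slice(-0, None) equals slice(0, None), so every position is kept; an
+        # int n keeps only the last n positions of the hidden-state stream, and a tensor selects explicit indices.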
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

-        # TE doesn't support left-side padding with `padding_causal` attention masks, and InferenceParams doesn't
-        # support arbitrary attention masks (and the attention backend for arbitrary masks is the slower, unfused
-        # backend). To keep the inference interface consistent with HF's `GenerationMixin.generate` interface, we use a
-        # `only_keep_last_logits` flag to indicate that we should pick out and return only the last token's hidden state
-        # during pre-fill. This allows generation to work with right-side padding. Note, make sure that you decode the
-        # tokens with `skip_special_tokens=True` when using this flag, otherwise padding tokens will interrupt the
-        # generated text.
-        if (
-            only_keep_last_logits
-            and attention_mask is not None  # Padded inputs
-            and hidden_states.shape[1] > 1  # We're in pre-fill mode
-        ):
-            seqlens = attention_mask.sum(dim=1)  # shape: [batch]
-            # For each batch idx, select hidden_states[idx, seqlens[idx]-1, :]
-            batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
-            selected_hidden_states = hidden_states[batch_indices, seqlens - 1, :]  # shape: [batch, hidden_dim]
-            hidden_states = selected_hidden_states.unsqueeze(1)  # shape: [batch, 1, hidden_dim]
-
-        logits = self.lm_head(hidden_states)
+        if hidden_states.ndim == 3:
+            logits = self.lm_head(hidden_states[:, slice_indices, :])
+        else:  # With THD inputs, batch and sequence dimensions are collapsed in the first dimension.
+            logits = self.lm_head(hidden_states[slice_indices, :])

        loss = None
        if labels is not None:
@@ -306,3 +296,86 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio
class NVLlamaForTokenClassification(  # noqa: D101
    transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel
): ...
+
+
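+# `_unpad_input` below calls `.max().item()`; capturing scalar outputs lets torch.compile trace through that call
+# instead of graph-breaking.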
+torch._dynamo.config.capture_scalar_outputs = True
+
+
+@torch.compile
+def _pad_input(hidden_states, indices, batch, seqlen):
+    """Convert a THD tensor to a BSHD equivalent tensor.
+
+    Adapted from huggingface/transformers/modeling_flash_attention_utils.py
+
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[1:]
+    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    output[indices] = hidden_states
+    return output.view(batch, seqlen, *dim)
+
+
+@torch.compile
+def _unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """Convert a BSHD tensor to a THD equivalent tensor.
+
+    Adapted from huggingface/transformers/modeling_flash_attention_utils.py
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of the non-masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    batch_size = hidden_states.size(0)
+    seq_length = hidden_states.size(1)
+
+    if attention_mask.shape[1] != seq_length:  # Likely in generation mode with kv-caching
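+        # Every sequence contributes a single new token here, so each length is 1 and cu_seqlens is arange(batch_size + 1).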
+        return (
+            hidden_states.squeeze(1),  # hidden_states
+            torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device),  # indices
+            torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device),  # cu_seqlens
+            1,  # max_seqlen
+            1,  # seqused
+        )
+
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+    return (
+        hidden_states.reshape(-1, *hidden_states.shape[2:])[indices],
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
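+
+# Round-trip sketch (illustrative): for `hidden` of shape [batch, seqlen, hidden_size] and a 0/1 `mask` of shape
+# [batch, seqlen],
+#     packed, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden, mask)
+#     restored = _pad_input(packed, indices, mask.size(0), mask.size(1))
+# recovers `hidden` at the valid positions and zero-fills the padded ones.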
+
+
+class HFInferenceParams(InferenceParams):
+    """Extension of the InferenceParams class to support beam search."""
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorder the cache based on the beam indices."""
+        if isinstance(self.cache_manager, PagedKVCacheManager):
+            raise NotImplementedError("Beam search is not supported for paged cache manager.")
+        for layer_number, (key_cache, value_cache) in self.cache_manager.cache.items():
+            updated_key_cache = key_cache.index_select(0, beam_idx)
+            updated_value_cache = value_cache.index_select(0, beam_idx)
+            self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
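+
+
+# Usage sketch (illustrative): if `params` is the HFInferenceParams instance passed to the model as
+# `past_key_values`, a beam-search step that selects `beam_idx` can reorder the non-paged KV cache in place with
+# `params.reorder_cache(beam_idx)`.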