
Commit fe6936c

always use THD inputs for llama3 model (#1322)
Because left-side padding isn't supported by TE, we can always ensure that our TE inputs are THD-packed to support the standard HF generation pipeline.

Signed-off-by: Peter St. John <[email protected]>
1 parent 3c1df50 commit fe6936c
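
For background on what "THD-packed" means here: instead of a padded [batch, seq] (BSHD) layout, all valid tokens are concatenated along a single token dimension and the batch structure is carried by cumulative sequence lengths. Below is a minimal sketch of deriving such inputs from a right-padded batch; the helper name is hypothetical, but the returned keys mirror the cu_seq_lens_q / cu_seq_lens_k / max_length_q / max_length_k kwargs consumed in this diff.

import torch


def thd_kwargs_from_padded_batch(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict:
    """Illustrative only: build THD-style kwargs from a right-padded [batch, seq] batch."""
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)  # valid tokens per sequence
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_len = int(seqlens.max())
    packed_ids = input_ids[attention_mask.bool()].unsqueeze(0)  # shape [1, total_tokens]
    return {
        "input_ids": packed_ids,
        "cu_seq_lens_q": cu_seqlens,
        "cu_seq_lens_k": cu_seqlens,
        "max_length_q": max_len,
        "max_length_k": max_len,
    }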

3 files changed: +199 -130 lines changed

bionemo-recipes/models/llama3/convert.py

Lines changed: 2 additions & 3 deletions
@@ -14,7 +14,6 @@
 # limitations under the License.

 import torch
-import torch.nn as nn
 from transformers import LlamaConfig, LlamaForCausalLM

 import state
@@ -35,7 +34,7 @@
 reverse_mapping = {v: k for k, v in mapping.items()}


-def convert_llama_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
+def convert_llama_hf_to_te(model_hf: LlamaForCausalLM, **config_kwargs) -> NVLlamaForCausalLM:
     """Convert a Hugging Face model to a Transformer Engine model.

     Args:
@@ -80,7 +79,7 @@ def convert_llama_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
     return output_model


-def convert_llama_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
+def convert_llama_te_to_hf(model_te: NVLlamaForCausalLM, **config_kwargs) -> LlamaForCausalLM:
     """Convert a Hugging Face model to a Transformer Engine model.

     Args:
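
The change above tightens the converter type hints from nn.Module to the concrete HF and TE model classes. A rough round-trip usage sketch follows; the flat `convert` import and the checkpoint name are assumptions for illustration, not something this diff shows.

from transformers import LlamaForCausalLM

import convert  # assumed: run from bionemo-recipes/models/llama3/

# Checkpoint name is a placeholder.
model_hf = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# HF -> TE, then back; weights should round-trip through the mapping in convert.py.
model_te = convert.convert_llama_hf_to_te(model_hf)
model_hf_roundtrip = convert.convert_llama_te_to_hf(model_te)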

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 142 additions & 69 deletions
@@ -21,6 +21,7 @@
 import transformer_engine.pytorch
 import transformers
 from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
 from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
 from transformers import LlamaConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -31,9 +32,6 @@
 class NVLlamaConfig(LlamaConfig):
     """NVLlama configuration."""

-    attn_input_format: str = "bshd"
-    self_attn_mask_type: str = "padding_causal"
-

 class NVLlamaPreTrainedModel(PreTrainedModel):
     """Base class for NVLlama models."""
@@ -68,8 +66,8 @@ def __init__(self, config: LlamaConfig):
                 qkv_weight_interleaved=True,
                 normalization="RMSNorm",
                 activation="swiglu",
-                attn_input_format=config.attn_input_format,
-                self_attn_mask_type=config.self_attn_mask_type,
+                attn_input_format="thd",
+                self_attn_mask_type="padding_causal",
                 num_gqa_groups=config.num_key_value_heads,
                 layer_number=layer_idx + 1,
                 params_dtype=config.dtype,
@@ -123,71 +121,76 @@ def forward(
         inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

         hidden_states = inputs_embeds
-        if self.config.attn_input_format == "bshd":
-            if past_key_values is not None:
-                max_seq_len = past_key_values.max_sequence_length
-            else:
-                max_seq_len = hidden_states.shape[1]
-            te_rope_emb = self.rotary_emb(max_seq_len=max_seq_len)
-        elif self.config.attn_input_format == "thd":
-            te_rope_emb = self.rotary_emb(max_seq_len=kwargs["cu_seq_lens_q"][-1])
-
-        has_thd_input = [
-            x is not None
-            for x in [
-                kwargs.get("cu_seq_lens_q", None),
-                kwargs.get("cu_seq_lens_k", None),
-                kwargs.get("max_length_q", None),
-                kwargs.get("max_length_k", None),
-            ]
-        ]

-        if isinstance(past_key_values, InferenceParams):
-            # lengths = attention_mask.sum(dim=1) if attention_mask is not None else torch.tensor([0])
-            lengths = input_ids.ne(0).sum(dim=1) if input_ids is not None else torch.tensor([0])
-            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
-
-        if self.config.attn_input_format == "thd":
-            if not all(has_thd_input):
-                raise ValueError(
-                    "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
-                )
+        has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]]
+        should_pack_inputs = not any(has_thd_input)
+
+        # This might be slower for BSHD + padding with fused attention backend. But it should be faster for the flash
+        # attention backend.
+        if should_pack_inputs:
+            # Left-side padding is not supported in TE layers, so to make generation work with TE we dynamically convert
+            # to THD-style inputs in our forward pass, and then convert back to BSHD for the output. This lets the
+            # entire transformer stack run in THD mode.
+            assert attention_mask is not None, "Attention mask is required when using BSHD inputs."
+            batch_size = hidden_states.size(0)
+            hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
+            cu_seq_lens_q = cu_seq_lens_k = cu_seqlens
+            max_length_q = max_length_k = max_seqlen
+
+        else:
+            # Here, we're providing THD-style inputs, so we can just grab the kwargs.
             assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
                 "THD expects embeddings shaped [1, total_tokens, hidden_size]."
             )
             hidden_states = hidden_states.squeeze(0)
-            attention_mask = None
+            cu_seq_lens_q = kwargs["cu_seq_lens_q"]
+            cu_seq_lens_k = kwargs["cu_seq_lens_k"]
+            max_length_q = kwargs["max_length_q"]
+            max_length_k = kwargs["max_length_k"]
+
+        # If we're using kv-caching, we can't trust the max_length_q value as the true max length for rotary
+        # embeddings, since this will be 1 in generation. Instead we can take the max sequence length from the past
+        # key values object.
+        te_rope_emb = self.rotary_emb(
+            max_seq_len=max_length_q if past_key_values is None else past_key_values.max_ctx_len
+        )

-        elif self.config.attn_input_format == "bshd" and any(has_thd_input):
-            raise ValueError(
-                "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
+        if isinstance(past_key_values, InferenceParams):
+            # In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
+            # compute the lengths of each sequence in the batch.
+            lengths = (
+                attention_mask.sum(dim=1).tolist()
+                if attention_mask.shape == input_ids.shape
+                else [1] * input_ids.shape[0]
             )
-
-        # Construct the appropriate attention mask.
-        if attention_mask is not None and self.config.self_attn_mask_type == "padding_causal":
-            attention_mask = ~attention_mask.to(bool)[:, None, None, :]
+            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states = (*all_hidden_states, hidden_states)

             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=None,
                 rotary_pos_emb=te_rope_emb,
                 inference_params=past_key_values,
-                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
-                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
-                max_seqlen_q=kwargs.get("max_length_q", None),
-                max_seqlen_kv=kwargs.get("max_length_k", None),
+                cu_seqlens_q=cu_seq_lens_q,
+                cu_seqlens_kv=cu_seq_lens_k,
+                max_seqlen_q=max_length_q,
+                max_seqlen_kv=max_length_k,
             )

         hidden_states = self.norm(hidden_states)

-        # add hidden states from the last decoder layer
+        # add hidden states from the last decoder layer. Note that these will be in THD format; we could possibly pad
+        # these with the same _pad_input call as below if we wanted them returned in BSHD format.
         if output_hidden_states:
             all_hidden_states = (*all_hidden_states, hidden_states)

+        if should_pack_inputs:
+            # If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
+            hidden_states = _pad_input(hidden_states, indices, batch_size, max_length_q)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
@@ -225,7 +228,7 @@ def forward(
         labels: torch.Tensor | None = None,
         use_cache: bool | None = None,
         cache_position: torch.Tensor | None = None,
-        only_keep_last_logits: bool = False,
+        logits_to_keep: int | torch.Tensor = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         """Forward pass for the NVLlamaForCausalLM model.
@@ -239,8 +242,8 @@ def forward(
             labels (torch.Tensor): The labels.
             use_cache (bool): Whether to use cache.
             cache_position (torch.Tensor): The cache position.
-            only_keep_last_logits (bool): Whether to keep only the last logits, as a workaround for the fact that TE
-                doesn't support left-side padding with `padding_causal` attention masks.
+            logits_to_keep (int | torch.Tensor): Whether to keep only the last logits to reduce the memory footprint of
+                the model during generation.
             **kwargs: Additional keyword arguments.

         Returns:
@@ -258,26 +261,13 @@
         )

         hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

-        # TE doesn't support left-side padding with `padding_causal` attention masks, and InferenceParams doesn't
-        # support arbitrary attention masks (and the attention backend for arbitrary masks is the slower, unfused
-        # backend). To keep the inference interface consistent with HF's `GenerationMixin.generate` interface, we use a
-        # `only_keep_last_logits` flag to indicate that we should pick out and return only the last token's hidden state
-        # during pre-fill. This allows generation to work with right-side padding. Note, make sure that you decode the
-        # tokens with `skip_special_tokens=True` when using this flag, otherwise padding tokens will interrupt the
-        # generated text.
-        if (
-            only_keep_last_logits
-            and attention_mask is not None  # Padded inputs
-            and hidden_states.shape[1] > 1  # We're in pre-fill mode
-        ):
-            seqlens = attention_mask.sum(dim=1)  # shape: [batch]
-            # For each batch idx, select hidden_states[idx, seqlens[idx]-1, :]
-            batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
-            selected_hidden_states = hidden_states[batch_indices, seqlens - 1, :]  # shape: [batch, hidden_dim]
-            hidden_states = selected_hidden_states.unsqueeze(1)  # shape: [batch, 1, hidden_dim]
-
-        logits = self.lm_head(hidden_states)
+        if hidden_states.ndim == 3:
+            logits = self.lm_head(hidden_states[:, slice_indices, :])
+        else:  # With THD inputs, batch and sequence dimensions are collapsed in the first dimension.
+            logits = self.lm_head(hidden_states[slice_indices, :])

         loss = None
         if labels is not None:
@@ -306,3 +296,86 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio
 class NVLlamaForTokenClassification(  # noqa: D101
     transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel
 ): ...
+
+
+torch._dynamo.config.capture_scalar_outputs = True
+
+
+@torch.compile
+def _pad_input(hidden_states, indices, batch, seqlen):
+    """Convert a THD tensor to a BSHD equivalent tensor.
+
+    Adapted from huggingface/transformers/modeling_flash_attention_utils.py
+
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[1:]
+    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    output[indices] = hidden_states
+    return output.view(batch, seqlen, *dim)
+
+
+@torch.compile
+def _unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """Convert a BSHD tensor to a THD equivalent tensor.
+
+    Adapted from huggingface/transformers/modeling_flash_attention_utils.py
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    batch_size = hidden_states.size(0)
+    seq_length = hidden_states.size(1)
+
+    if attention_mask.shape[1] != seq_length:  # Likely in generation mode with kv-caching
+        return (
+            hidden_states.squeeze(1),  # hidden_states
+            torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device),  # indices
+            torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device),  # cu_seqlens
+            1,  # max_seqlen
+            1,  # seqused
+        )
+
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+    return (
+        hidden_states.reshape(-1, *hidden_states.shape[2:])[indices],
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
+class HFInferenceParams(InferenceParams):
+    """Extension of the InferenceParams class to support beam search."""
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorder the cache based on the beam indices."""
+        if isinstance(self.cache_manager, PagedKVCacheManager):
+            raise NotImplementedError("Beam search is not supported for paged cache manager.")
+        for layer_number, (key_cache, value_cache) in self.cache_manager.cache.items():
+            updated_key_cache = key_cache.index_select(0, beam_idx)
+            updated_value_cache = value_cache.index_select(0, beam_idx)
+            self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
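
Taken together, the new forward path unpads padded BSHD batches to THD before the TE layers and re-pads afterwards, so generation can go through the stock HF pipeline. A hedged sketch of what that could look like, assuming a converted checkpoint saved with save_pretrained and a matching tokenizer (paths and setup details are placeholders, not part of this commit):

import torch
from transformers import AutoTokenizer

from modeling_llama_te import NVLlamaForCausalLM  # module path assumed

tokenizer = AutoTokenizer.from_pretrained("path/to/converted-checkpoint")  # placeholder
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

model = NVLlamaForCausalLM.from_pretrained(
    "path/to/converted-checkpoint", torch_dtype=torch.bfloat16
).cuda()

# Padded batch of prompts; the model unpads BSHD + mask to THD internally.
inputs = tokenizer(["Proteins are", "DNA encodes"], padding=True, return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))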
