
Commit 3c1df50

Add BSHD inputs for llama3 model and ensure that inference works (#1303)
Better support for llama3 model batching, THD and BSHD inputs, and a demonstration of how to do inference with InferenceParams.

Signed-off-by: Peter St. John <[email protected]>
1 parent 98dda86 commit 3c1df50

4 files changed (+348, -79 lines)


bionemo-recipes/models/llama3/convert.py

Lines changed: 5 additions & 1 deletion
@@ -90,7 +90,11 @@ def convert_llama_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
     Returns:
         nn.Module: The converted Hugging Face model.
     """
-    hf_config = LlamaConfig(**model_te.config.to_dict(), **config_kwargs)
+    # Filter out keys from model_te.config that are not valid LlamaConfig attributes
+    te_config_dict = model_te.config.to_dict()
+    valid_keys = set(LlamaConfig.__init__.__code__.co_varnames)
+    filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys}
+    hf_config = LlamaConfig(**filtered_config, **config_kwargs)

     with torch.device("meta"):
         model_hf = LlamaForCausalLM(hf_config)
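
Example (not part of the commit): a minimal sketch of the key-filtering pattern added above. The dictionary contents here are illustrative assumptions; only the filtering logic mirrors the new lines.

from transformers import LlamaConfig

# A stand-in for model_te.config.to_dict(); "attn_input_format" plays the role of an NVLlama-only key.
te_config_dict = {"hidden_size": 256, "num_attention_heads": 8, "attn_input_format": "bshd"}

# co_varnames lists the parameter (and local) names of LlamaConfig.__init__, so NVLlama-only keys drop out.
valid_keys = set(LlamaConfig.__init__.__code__.co_varnames)
filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys}

hf_config = LlamaConfig(**filtered_config)
print(sorted(set(te_config_dict) - set(filtered_config)))  # ['attn_input_format']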

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 87 additions & 51 deletions
@@ -13,20 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import OrderedDict
 from typing import Unpack

 import torch
 import torch.nn as nn
 import transformer_engine.pytorch
 import transformers
+from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
 from transformers import LlamaConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_rope_utils import dynamic_rope_update
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.utils.generic import TransformersKwargs


-class NVLlamaConfig(LlamaConfig): ...  # noqa: D101
+class NVLlamaConfig(LlamaConfig):
+    """NVLlama configuration."""
+
+    attn_input_format: str = "bshd"
+    self_attn_mask_type: str = "padding_causal"


 class NVLlamaPreTrainedModel(PreTrainedModel):
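
Example (not part of the commit): the two new config fields can be left at their class defaults or overridden per instantiation; the tiny num_hidden_layers value here is an illustrative assumption.

from modeling_llama_te import NVLlamaConfig

bshd_config = NVLlamaConfig(num_hidden_layers=2)  # class defaults: "bshd" and "padding_causal"
thd_config = NVLlamaConfig(num_hidden_layers=2, attn_input_format="thd")

print(bshd_config.attn_input_format, bshd_config.self_attn_mask_type)  # bshd padding_causal
print(thd_config.attn_input_format)  # thd
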
@@ -62,7 +68,8 @@ def __init__(self, config: LlamaConfig):
                     qkv_weight_interleaved=True,
                     normalization="RMSNorm",
                     activation="swiglu",
-                    attn_input_format="bshd",
+                    attn_input_format=config.attn_input_format,
+                    self_attn_mask_type=config.self_attn_mask_type,
                     num_gqa_groups=config.num_key_value_heads,
                     layer_number=layer_idx + 1,
                     params_dtype=config.dtype,
@@ -71,7 +78,12 @@ def __init__(self, config: LlamaConfig):
             ]
         )
         self.norm = transformer_engine.pytorch.RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=config.dtype)
-        self.rotary_emb = NVLlamaRotaryEmbedding(config=config)
+
+        # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
+        # LlamaRotaryEmbedding.
+        self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
+        self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
+
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
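
Example (not part of the commit): a sketch of the rotary-embedding wiring above, copying Hugging Face's inv_freq onto TE's module; the small config values are assumptions, and a Transformer Engine install is assumed.

import torch
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig(hidden_size=256, num_attention_heads=8)  # head_dim = 32
te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
te_rope.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq  # reuse HF's inverse frequencies

torch.testing.assert_close(te_rope.inv_freq, LlamaRotaryEmbedding(config=config).inv_freq)
# te_rope(max_seq_len=...) then yields the frequency tensor consumed by TE's fused RoPE (see forward below).
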
@@ -82,9 +94,8 @@ def forward(
         input_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
-        past_key_values: tuple[tuple[torch.Tensor, ...], ...] | None = None,
+        past_key_values: InferenceParams | None = None,
         inputs_embeds: torch.Tensor | None = None,
-        cache_position: torch.Tensor | None = None,
         use_cache: bool | None = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
@@ -96,7 +107,6 @@ def forward(
             position_ids (torch.Tensor): The position ids.
             past_key_values (tuple[tuple[torch.Tensor, ...], ...]): The past key values.
             inputs_embeds (torch.Tensor): The inputs embeds.
-            cache_position (torch.Tensor): The cache position.
             use_cache (bool): Whether to use cache.
             **kwargs: Additional keyword arguments.

@@ -112,34 +122,64 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None:
-            past_key_values = transformers.cache_utils.DynamicCache(config=self.config)
+        hidden_states = inputs_embeds
+        if self.config.attn_input_format == "bshd":
+            if past_key_values is not None:
+                max_seq_len = past_key_values.max_sequence_length
+            else:
+                max_seq_len = hidden_states.shape[1]
+            te_rope_emb = self.rotary_emb(max_seq_len=max_seq_len)
+        elif self.config.attn_input_format == "thd":
+            te_rope_emb = self.rotary_emb(max_seq_len=kwargs["cu_seq_lens_q"][-1])
+
+        has_thd_input = [
+            x is not None
+            for x in [
+                kwargs.get("cu_seq_lens_q", None),
+                kwargs.get("cu_seq_lens_k", None),
+                kwargs.get("max_length_q", None),
+                kwargs.get("max_length_k", None),
+            ]
+        ]

-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position: torch.Tensor = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        if isinstance(past_key_values, InferenceParams):
+            # lengths = attention_mask.sum(dim=1) if attention_mask is not None else torch.tensor([0])
+            lengths = input_ids.ne(0).sum(dim=1) if input_ids is not None else torch.tensor([0])
+            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
+
+        if self.config.attn_input_format == "thd":
+            if not all(has_thd_input):
+                raise ValueError(
+                    "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
+                )
+            assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
+                "THD expects embeddings shaped [1, total_tokens, hidden_size]."
             )
+            hidden_states = hidden_states.squeeze(0)
+            attention_mask = None

-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+        elif self.config.attn_input_format == "bshd" and any(has_thd_input):
+            raise ValueError(
+                "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
+            )

-        hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        # Construct the appropriate attention mask.
+        if attention_mask is not None and self.config.self_attn_mask_type == "padding_causal":
+            attention_mask = ~attention_mask.to(bool)[:, None, None, :]

         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states = (*all_hidden_states, hidden_states)

             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=None,
-                self_attn_mask_type="causal",
-                rotary_pos_emb=position_embeddings,
-                # position_ids=position_ids,
-                # past_key_values=past_key_values,
-                # cache_position=cache_position,
-                # **kwargs,
+                attention_mask=attention_mask,
+                rotary_pos_emb=te_rope_emb,
+                inference_params=past_key_values,
+                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
+                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+                max_seqlen_q=kwargs.get("max_length_q", None),
+                max_seqlen_kv=kwargs.get("max_length_k", None),
             )

         hidden_states = self.norm(hidden_states)
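
Example (not part of the commit): a sketch of how the THD ("packed") kwargs validated above could be assembled for two unpadded sequences; the token values are illustrative assumptions, and in practice a padding-free collator would produce these fields.

from itertools import accumulate

import torch

seqs = [torch.tensor([11, 12, 13, 14]), torch.tensor([21, 22, 23])]  # two unpadded sequences
input_ids = torch.cat(seqs).unsqueeze(0)                             # THD layout: [1, total_tokens]
seq_lens = [len(s) for s in seqs]                                    # [4, 3]

thd_kwargs = {
    "cu_seq_lens_q": torch.tensor([0, *accumulate(seq_lens)], dtype=torch.int32),  # tensor([0, 4, 7])
    "cu_seq_lens_k": torch.tensor([0, *accumulate(seq_lens)], dtype=torch.int32),
    "max_length_q": max(seq_lens),
    "max_length_k": max(seq_lens),
}
# model(input_ids=input_ids, **thd_kwargs) with a config whose attn_input_format == "thd".
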
@@ -185,7 +225,7 @@ def forward(
         labels: torch.Tensor | None = None,
         use_cache: bool | None = None,
         cache_position: torch.Tensor | None = None,
-        logits_to_keep: int | torch.Tensor = 0,
+        only_keep_last_logits: bool = False,
         **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         """Forward pass for the NVLlamaForCausalLM model.
@@ -199,7 +239,8 @@ def forward(
             labels (torch.Tensor): The labels.
             use_cache (bool): Whether to use cache.
             cache_position (torch.Tensor): The cache position.
-            logits_to_keep (int | torch.Tensor): The logits to keep.
+            only_keep_last_logits (bool): Whether to keep only the last logits, as a workaround for the fact that TE
+                doesn't support left-side padding with `padding_causal` attention masks.
             **kwargs: Additional keyword arguments.

         Returns:
@@ -217,9 +258,26 @@ def forward(
         )

         hidden_states = outputs.last_hidden_state
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        # TE doesn't support left-side padding with `padding_causal` attention masks, and InferenceParams doesn't
+        # support arbitrary attention masks (and the attention backend for arbitrary masks is the slower, unfused
+        # backend). To keep the inference interface consistent with HF's `GenerationMixin.generate` interface, we use a
+        # `only_keep_last_logits` flag to indicate that we should pick out and return only the last token's hidden state
+        # during pre-fill. This allows generation to work with right-side padding. Note, make sure that you decode the
+        # tokens with `skip_special_tokens=True` when using this flag, otherwise padding tokens will interrupt the
+        # generated text.
+        if (
+            only_keep_last_logits
+            and attention_mask is not None  # Padded inputs
+            and hidden_states.shape[1] > 1  # We're in pre-fill mode
+        ):
+            seqlens = attention_mask.sum(dim=1)  # shape: [batch]
+            # For each batch idx, select hidden_states[idx, seqlens[idx]-1, :]
+            batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
+            selected_hidden_states = hidden_states[batch_indices, seqlens - 1, :]  # shape: [batch, hidden_dim]
+            hidden_states = selected_hidden_states.unsqueeze(1)  # shape: [batch, 1, hidden_dim]
+
+        logits = self.lm_head(hidden_states)

         loss = None
         if labels is not None:
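
Example (not part of the commit): the last-valid-token selection performed by the `only_keep_last_logits` path above, shown on small illustrative tensors.

import torch

batch, seq, hidden = 2, 5, 4
hidden_states = torch.randn(batch, seq, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],   # right-padded to length 5
                               [1, 1, 1, 1, 1]])

seqlens = attention_mask.sum(dim=1)                         # tensor([3, 5])
batch_indices = torch.arange(batch)
last_hidden = hidden_states[batch_indices, seqlens - 1, :]  # shape: [batch, hidden]
print(last_hidden.shape)                                    # torch.Size([2, 4])
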
@@ -248,25 +306,3 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio
 class NVLlamaForTokenClassification(  # noqa: D101
     transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel
 ): ...
-
-
-class NVLlamaRotaryEmbedding(LlamaRotaryEmbedding):
-    """Slight modification of the LlamaRotaryEmbedding for use with Transformer Engine."""
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):  # pyright: ignore[reportIncompatibleMethodOverride]
-        """Forward pass for the NVLlamaRotaryEmbedding.
-
-        Unlike the original LlamaRotaryEmbedding, this implementation returns the frequency tensor (upstream of the
-        cosine and sine transforms), reshaped in a way that is compatible with TransformerEngine's fused RoPE.
-        """
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-
-        return emb.transpose(0, 1).unsqueeze(1)
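
Example (not part of the commit): the per-sequence length bookkeeping that the forward pass above hands to InferenceParams.pre_step during pre-fill. The pad id of 0 mirrors the input_ids.ne(0) check; constructing InferenceParams itself is TE-version dependent and not shown here.

from collections import OrderedDict

import torch

input_ids = torch.tensor([[15, 16, 17, 0, 0],
                          [18, 19, 20, 21, 0]])  # right-padded batch, pad id 0
lengths = input_ids.ne(0).sum(dim=1)             # tensor([3, 4])
step_dict = OrderedDict(zip(range(len(lengths)), lengths.tolist()))
print(step_dict)                                 # OrderedDict([(0, 3), (1, 4)])
# inference_params.pre_step(step_dict) registers these lengths with TE's KV cache before the layer calls.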

bionemo-recipes/models/llama3/tests/test_convert.py

Lines changed: 4 additions & 2 deletions
@@ -20,7 +20,7 @@
 from transformers.models.llama.modeling_llama import LlamaForCausalLM

 from convert import convert_llama_hf_to_te, convert_llama_te_to_hf
-from modeling_llama_te import NVLlamaForCausalLM
+from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM


 def test_convert_llama_hf_to_te_roundtrip(caplog):
@@ -71,7 +71,9 @@ def test_convert_hf_to_te_with_bf16():


 def test_convert_te_to_hf_with_bf16():
-    config = AutoConfig.from_pretrained("nvidia/Llama-3.1-8B-Instruct-FP8", dtype=torch.bfloat16, num_hidden_layers=2)
+    config = NVLlamaConfig.from_pretrained(
+        "nvidia/Llama-3.1-8B-Instruct-FP8", dtype=torch.bfloat16, num_hidden_layers=2
+    )
     model_te = NVLlamaForCausalLM(config)
     model_te.to(dtype=torch.float32)  # I think the original llama3 model doesn't initialize in bf16.
     convert_llama_te_to_hf(model_te)
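
Example (not part of the commit): a sketch of the conversion round-trip these tests exercise, using a tiny local config instead of the pretrained checkpoint; all sizes are illustrative assumptions and a Transformer Engine environment is assumed.

import torch

from convert import convert_llama_te_to_hf
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM

config = NVLlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_hidden_layers=2,
    dtype=torch.float32,
)
model_te = NVLlamaForCausalLM(config)
model_hf = convert_llama_te_to_hf(model_te)  # NVLlama-only config keys are now filtered out (see convert.py)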
