Commit 98dda86

Make the eden model inherit from llama3.1 (#1316)
### Description

Since the eden config inherits from llama3 rather than llama3.1 the default nemo conversion classes do not save the `rope_scaling` settings:

```
(Pdb) config
LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": null,
  "dtype": "bfloat16",
  "eos_token_id": 0,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000,
  "tie_word_embeddings": false,
  "transformers_version": "4.57.1",
  "use_cache": true,
  "vocab_size": 512
}
```

which should happen in NeMo with an isinstance match:

```
# For Llama 3.1 and Llama 3.2, rope_scaling is used and thus needed to parsed to the config
if isinstance(source, Llama31Config):
    rope_scaling = {
        'factor': source.scale_factor,
        'low_freq_factor': source.low_freq_factor,
        'high_freq_factor': source.high_freq_factor,
        'original_max_position_embeddings': source.old_context_len,
        'rope_type': 'llama3',
    }
```

This change modifies the inheritance structure so that this matches with the intended llama3.1 config that has the inverse frequency override.
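The effect of the inheritance change on an `isinstance` check can be illustrated with a minimal sketch. All classes below are hypothetical stand-ins rather than the NeMo classes, and the default field values are assumptions chosen for illustration; only the class hierarchy matters.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for the config classes; not the NeMo implementations.
@dataclass
class Llama3Config:
    rotary_base: int = 500_000


@dataclass
class Llama31Config(Llama3Config):
    # Assumed default values, for illustration only.
    scale_factor: float = 8.0
    low_freq_factor: float = 1.0
    high_freq_factor: float = 4.0
    old_context_len: int = 8192


@dataclass
class EdenOnLlama3(Llama3Config):
    """Old layout: carries the scaling fields but is not a Llama31Config."""

    scale_factor: float = 8.0
    low_freq_factor: float = 1.0
    high_freq_factor: float = 4.0
    old_context_len: int = 8192


@dataclass
class EdenOnLlama31(Llama31Config):
    """New layout: a Llama31Config subclass, so the isinstance check matches."""


def export_rope_scaling(source) -> Optional[dict]:
    """Mimic an exporter that only emits rope_scaling for Llama 3.1 configs."""
    if isinstance(source, Llama31Config):
        return {
            "factor": source.scale_factor,
            "low_freq_factor": source.low_freq_factor,
            "high_freq_factor": source.high_freq_factor,
            "original_max_position_embeddings": source.old_context_len,
            "rope_type": "llama3",
        }
    return None


old_result = export_rope_scaling(EdenOnLlama3())
new_result = export_rope_scaling(EdenOnLlama31())
```

In this sketch `old_result` is `None` even though `EdenOnLlama3` has every scaling field, because the check is on the class hierarchy and not on the attributes.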
#### Usage

```bash
BIONEMO_DATA_SOURCE=pbss py.test \
  sub-packages/bionemo-evo2/tests/bionemo/evo2/models/test_llama.py \
  sub-packages/bionemo-evo2/tests/bionemo/evo2/utils/checkpoint/test_eden_llama_roundtrip.py
```

Returns:

```
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================= slowest durations =============================
213.80s call sub-packages/bionemo-evo2/tests/bionemo/evo2/utils/checkpoint/test_eden_llama_roundtrip.py::test_eden_llama_roundtrip
74.26s call sub-packages/bionemo-evo2/tests/bionemo/evo2/models/test_llama.py::test_checkpoint_conversion
42.58s call sub-packages/bionemo-evo2/tests/bionemo/evo2/models/test_llama.py::test_golden_values_llama

(6 durations < 30s hidden. Use -vv to show these durations.)
================ 3 passed, 76 warnings in 343.93s (0:05:43) ================
Skipping execution of on_app_end because OneLogger is not enabled.
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
```

### Type of changes

<!-- Mark the relevant option with an [x] -->

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.
- [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR
- [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests for bionemo2
- [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
- [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline.

For more details, see [CONTRIBUTING](CONTRIBUTING.md)

> [!NOTE]
> By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

#### Authorizing CI Runs

We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources.

- If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
- If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit.
### Pre-submit Checklist

<!--- Ensure all items are completed before submitting -->

- [ ] I have tested these changes locally
- [ ] I have updated the documentation accordingly
- [ ] I have added/updated tests as needed
- [ ] All existing tests pass successfully

---------

Signed-off-by: John St John <[email protected]>
1 parent 3f66690 commit 98dda86

File tree

5 files changed

+183
-43
lines changed

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+- tag: eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814:1.0
+  ngc: null
+  ngc_registry: resource
+  pbss: "s3://bionemo-ci/test_data/evo2/eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814.tar.gz"
+  sha256: 7e13dde3ff1c2be113dcbd73de812b29b229cba700b7c981eb048e16dbb6b0cb # pragma: allowlist secret
+  owner: John St John <[email protected]>
+  description: >
+    Test data for Evo2 llama inference.
+
+- tag: 7B-8k-og2:1.0
+  ngc: null
+  ngc_registry: model
+  pbss: "s3://bionemo-ci/models/eden_llama_og2_step_182313.tar.gz"
+  sha256: 80a9dae5155f10c9c48e913be55900f51f231fab1252464938867c7511035010 # pragma: allowlist secret
+  owner: John St John <[email protected]>
+  description: >
+    7b llama 3.1 checkpoint trained on the open genome 2 metagenome subset data for approximately 250 billion tokens.

sub-packages/bionemo-evo2/src/bionemo/evo2/models/llama.py

Lines changed: 3 additions & 19 deletions
@@ -19,15 +19,15 @@
 
 import torch
 from nemo.collections import llm
-from nemo.collections.llm.gpt.model.llama import HFLlamaImporter, LlamaModel, apply_rope_scaling
+from nemo.collections.llm.gpt.model.llama import HFLlamaImporter, LlamaModel
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.lightning import io
 from nemo.lightning.pytorch.utils import dtype_from_hf
 
 
 @dataclass
-class EdenConfig(llm.Llama3Config8B):
-    """Eden-flavoured Llama-3.1 ~8B (keeps all Eden behaviors)."""
+class EdenConfig(llm.Llama31Config8B):
+    """Eden-flavoured Llama-3.1 ~8B (keeps all Eden behaviors). Inherits from the llama 3.1 config for proper handling of RoPE when converting checkpoints."""
 
     rotary_base: int = 500_000
     seq_length: int = 8192
@@ -43,22 +43,6 @@ class EdenConfig(llm.Llama3Config8B):
     init_method_std: float = 0.02
     embedding_init_method_std: Optional[float] = None
 
-    def configure_model(self, *args, **kwargs):
-        """Configure and instantiate a Megatron Core Llama 3.1 model.
-
-        Extends the base configuration with Llama 3.1 specific RoPE scaling.
-        """
-        model = super(EdenConfig, self).configure_model(*args, **kwargs)
-        # Apply rope scaling for Llama3.1 model
-        model.rotary_pos_emb.inv_freq = apply_rope_scaling(
-            model.rotary_pos_emb.inv_freq,
-            factor=self.scale_factor,
-            low_freq_factor=self.low_freq_factor,
-            high_freq_factor=self.high_freq_factor,
-            old_context_len=self.old_context_len,
-        )
-        return model
-
 
 @dataclass
 class Eden11BConfig(EdenConfig):
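
The removed `configure_model` override rescaled the rotary inverse frequencies by hand. As a rough illustration of what that kind of scaling does, here is a pure-Python sketch of the piecewise Llama 3.1 style scheme. The function name `llama31_scale_inv_freq` is hypothetical, the default parameter values are assumptions, and NeMo's `apply_rope_scaling` may differ in detail; `head_dim` and `rotary_base` are taken from the config shown in the description.

```python
import math


def llama31_scale_inv_freq(
    inv_freq,
    factor=8.0,            # assumed default, not shown in this commit
    low_freq_factor=1.0,   # assumed default
    high_freq_factor=4.0,  # assumed default
    old_context_len=8192,  # assumed default
):
    """Rescale rotary inverse frequencies with a piecewise rule (illustrative sketch)."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    scaled = []
    for freq in inv_freq:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (high frequencies) are left unchanged.
            scaled.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelengths (low frequencies) are divided by the scale factor.
            scaled.append(freq / factor)
        else:
            # In between, interpolate smoothly between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled


head_dim = 128
rotary_base = 500_000
inv_freq = [rotary_base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
scaled = llama31_scale_inv_freq(inv_freq)
```

With these assumed defaults the highest frequency is untouched and the lowest is divided by `factor`, which is the behaviour the `rope_scaling` block in the exported config is meant to describe.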
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import subprocess
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM
+
+from bionemo.core.data.load import load
+
+
+@pytest.fixture(scope="module")
+def eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814():
+    """Test data for Evo2 llama inference.
+
+    Returns:
+        tree
+        .
+        ├── per_layer_activations
+        │   └── activations_rank000_dl00_batch000000.pt
+        ├── predictions__rank_0__dp_rank_0.pt
+        ├── ribosomal_rrna_highly_conserved_PMC4140814.fasta
+        └── seq_idx_map.json
+
+        1 directory, 4 files
+    """
+    return load("evo2_llama/eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814:1.0")
+
+
+@pytest.fixture(scope="module")
+def llama_7b_8k_og2():
+    return load("evo2_llama/7B-8k-og2:1.0")
+
+
+@pytest.mark.skipif(os.environ.get("BIONEMO_DATA_SOURCE") != "pbss", reason="Test data is not available on NGC")
+def test_golden_values_llama(
+    tmp_path, eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814, llama_7b_8k_og2
+):
+    fasta_path = (
+        eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814
+        / "ribosomal_rrna_highly_conserved_PMC4140814.fasta"
+    )
+    gold_values_path = (
+        eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814 / "predictions__rank_0__dp_rank_0.pt"
+    )
+    output_dir = tmp_path / "predictions_llama"
+    prediction_cmd = (
+        f"predict_evo2 --fasta {fasta_path} --ckpt-dir {llama_7b_8k_og2} --output-dir {output_dir} --model-size 7B"
+    )
+    subprocess.run(prediction_cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    predictions = torch.load(output_dir / "predictions__rank_0__dp_rank_0.pt", weights_only=True)
+    gold_values = torch.load(gold_values_path, weights_only=True)
+    assert predictions["token_logits"].shape == gold_values["token_logits"].shape
+    torch.testing.assert_close(predictions["token_logits"], gold_values["token_logits"], atol=0.5, rtol=0)
+
+
+@pytest.mark.skipif(os.environ.get("BIONEMO_DATA_SOURCE") != "pbss", reason="Test data is not available on NGC")
+def test_checkpoint_conversion(
+    tmp_path, eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814, llama_7b_8k_og2
+):
+    target_dir = tmp_path / "llama_7b_8k_og2"
+    convert_cmd = f"evo2_nemo2_to_hf --model-type llama --model-path {llama_7b_8k_og2} --output-dir {target_dir}"
+    subprocess.run(convert_cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    assert target_dir.exists()
+    assert target_dir.is_dir()
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        target_dir,
+        torch_dtype=torch.bfloat16,
+        local_files_only=True,  # Force loading from local path, not HF Hub
+        use_cache=False,  # Disable use_cache to get the correct forward pass outside of generate.
+    ).eval()
+    # # Add hooks to capture inputs/outputs for forward pass
+    # activations = {}
+    # def capture_hook(name):
+    #     def hook(module, input, output):
+    #         # if not isinstance(input, torch.Tensor):
+    #         #     input = None
+    #         # if not isinstance(output, torch.Tensor):
+    #         #     output = None
+    #         activations[name] = {
+    #             'input': input,
+    #             'output': output
+    #         }
+    #     return hook
+    # # Register hooks on key layers
+    # for name, module in hf_model.named_modules():
+    #     module.register_forward_hook(capture_hook(name))
+    fasta_path = (
+        eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814
+        / "ribosomal_rrna_highly_conserved_PMC4140814.fasta"
+    )
+    with open(fasta_path, "r") as f:
+        fasta_seq = f.readlines()[1].strip()
+    input_ids = torch.tensor([ord(c) for c in fasta_seq]).unsqueeze(0)  # add batch dimension
+    with torch.no_grad():
+        outputs = hf_model(input_ids)
+    gold_values_path = (
+        eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814 / "predictions__rank_0__dp_rank_0.pt"
+    )
+    gold_values = torch.load(gold_values_path, weights_only=True)
+    assert outputs.logits.shape == gold_values["token_logits"].shape
+    torch.testing.assert_close(outputs.logits, gold_values["token_logits"].to(dtype=torch.bfloat16), atol=0.5, rtol=0)
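
The conversion test builds `input_ids` by taking the second line of the FASTA file and mapping each character to its code point with `ord`. A small self-contained sketch of that step, using an in-memory file and a hypothetical record (the helper name `fasta_second_line_to_ids` is not part of the commit):

```python
import io


def fasta_second_line_to_ids(handle):
    """Return code-point token ids for the second line of a FASTA-like stream.

    Mirrors the `f.readlines()[1].strip()` + `ord(c)` pattern in the test.
    Only the second line is read, so a sequence wrapped over several lines
    would be truncated to its first line.
    """
    lines = handle.readlines()
    sequence = lines[1].strip()
    return [ord(c) for c in sequence]


# Hypothetical single-record input; the header and sequence are made up.
example = io.StringIO(">hypothetical_record\nACGT\n")
ids = fasta_second_line_to_ids(example)
```

For plain nucleotide characters the resulting ids all fall below the `vocab_size` of 512 shown in the config dump in the description.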

sub-packages/bionemo-evo2/tests/bionemo/evo2/utils/checkpoint/test_eden_llama_roundtrip.py

Lines changed: 32 additions & 24 deletions
@@ -14,13 +14,14 @@
 # limitations under the License.
 
 import json
+import os
 from pathlib import Path
 
 import pytest
 import torch
-from lightning.fabric.plugins.environments.lightning import find_free_network_port
 from nemo.collections.llm.gpt.model.llama import HFLlamaExporter
 
+from bionemo.core.data.load import load
 from bionemo.evo2.models.llama import HFEdenLlamaImporter
 from bionemo.llm.lightning import batch_collator
 from bionemo.testing.subprocess_utils import run_command_in_subprocess
@@ -30,62 +31,69 @@
 
 
 @pytest.fixture(scope="module")
-def checkpoint_eden_path() -> Path:
-    """
-    mkdir -p $REPO_PATH/tmp_checkpoints
-    scp -r jstjohn@computelab-sc-01:/home/jstjohn/scratch/checkpoints/eden_llama_og2_step_182313 $REPO_PATH/tmp_checkpoints/
+def eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814():
+    """Test data for Evo2 llama inference.
+
+    Returns:
+        tree
+        .
+        ├── per_layer_activations
+        │   └── activations_rank000_dl00_batch000000.pt
+        ├── predictions__rank_0__dp_rank_0.pt
+        ├── ribosomal_rrna_highly_conserved_PMC4140814.fasta
+        └── seq_idx_map.json
+
+        1 directory, 4 files
     """
-    return REPO_PATH / "tmp_checkpoints" / "eden_llama_og2_step_182313"
+    return load("evo2_llama/eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814:1.0")
 
 
 @pytest.fixture(scope="module")
-def metagenome_fasta_path() -> Path:
-    """
-    mkdir -p $REPO_PATH/tmp_data
-    scp -r jstjohn@computelab-sc-01:/home/jstjohn/scratch/experiments/evo2_activations/ckpt_lm_loss_evals/lm_loss_work/evo2_metagenomics_test_only_sl8192_sd42.fasta $REPO_PATH/tmp_data/
-    """
-    return REPO_PATH / "tmp_data" / "evo2_metagenomics_test_only_sl8192_sd42.fasta"
+def llama_7b_8k_og2():
+    return load("evo2_llama/7B-8k-og2:1.0")
 
 
 def predict_metagenome(
     model_checkpoint_path: Path, metagenome_fasta_path: Path, output_path: Path
 ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
-    port = find_free_network_port()
-    cmd = f"""NCCL_P2P_DISABLE=1 torchrun --nproc_per_node=2 --master-port={port} --no-python \
-    predict_evo2 \
+    cmd = f"""predict_evo2 \
     --eden-tokenizer \
-    --devices=2 \
     --model-size 7B \
-    --tensor-parallel-size=2 \
     --fasta {metagenome_fasta_path} \
     --ckpt-dir {model_checkpoint_path} \
     --output-log-prob-seqs \
     --log-prob-collapse-option per_token \
     --output-dir {output_path}"""
-    run_command_in_subprocess(cmd, str(REPO_PATH))
+    run_command_in_subprocess(cmd, os.getcwd())
     with open(output_path / "seq_idx_map.json", "r") as jsonf:
         fasta_to_index = json.load(jsonf)
     preds_list = [torch.load(f) for f in output_path.glob("*.pt")]
     all_pt_data = batch_collator([item for item in preds_list if item is not None])
     return all_pt_data, fasta_to_index  # type: ignore
 
 
+@pytest.mark.skipif(os.environ.get("BIONEMO_DATA_SOURCE") != "pbss", reason="Test data is not available on NGC")
 @pytest.mark.slow
-def test_eden_llama_roundtrip(tmp_path, checkpoint_eden_path: Path, metagenome_fasta_path: Path):
+def test_eden_llama_roundtrip(
+    tmp_path, llama_7b_8k_og2: Path, eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814: Path
+):
     """Test that converting NeMo -> HF -> NeMo produces the same model."""
-    if not checkpoint_eden_path.exists() or not metagenome_fasta_path.exists():
-        pytest.skip("Skipping test, first download the checkpoint and the metagenome fasta.")
+    fasta_path = (
+        eden_llama_og2_step_182313_on_evo2_rrna_highly_conserved_PMC4140814
+        / "ribosomal_rrna_highly_conserved_PMC4140814.fasta"
+    )
+    assert llama_7b_8k_og2.exists() and fasta_path.exists()
 
-    exporter = HFLlamaExporter(checkpoint_eden_path)
+    exporter = HFLlamaExporter(llama_7b_8k_og2)
     hf_path = tmp_path / "hf_checkpoint"
     exporter.apply(hf_path)
     importer = HFEdenLlamaImporter(hf_path)
     importer.apply(tmp_path / "nemo_checkpoint")
     original_predictions, original_fasta_to_index = predict_metagenome(
-        checkpoint_eden_path, metagenome_fasta_path, tmp_path / "original_predictions"
+        llama_7b_8k_og2, fasta_path, tmp_path / "original_predictions"
     )
     new_predictions, new_fasta_to_index = predict_metagenome(
-        tmp_path / "nemo_checkpoint", metagenome_fasta_path, tmp_path / "new_predictions"
+        tmp_path / "nemo_checkpoint", fasta_path, tmp_path / "new_predictions"
     )
     assert original_fasta_to_index == new_fasta_to_index, "Fasta to index mapping is not the same, need better logic."
     for key in ["seq_idx", "log_probs_seqs", "loss_mask"]:
