diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py index e75c02cfa1..aec6cf9ae1 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py @@ -24,7 +24,7 @@ # TODO add back support for slurm resilience. # import nvidia_resiliency_ext.ptl_resiliency as res_module import torch -from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary +from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.enums import Fp8Recipe from megatron.core.optimizer import OptimizerConfig @@ -53,7 +53,7 @@ from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings from bionemo.evo2.models.peft import Evo2LoRA from bionemo.evo2.run.utils import infer_model_type, patch_eden_tokenizer -from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime +from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime, _FirstBatchCudaSync from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings from bionemo.evo2.utils.logging.callbacks import TEVCallback from bionemo.llm.utils.datamodule_utils import infer_global_batch_size @@ -853,27 +853,6 @@ def train(args: argparse.Namespace) -> nl.Trainer: TEVCallback(), ] - # First batch CUDA sync callback: adds barriers for the first training batch to avoid race condition - # See https://github.com/NVIDIA/bionemo-framework/issues/1301 for more details. - class _FirstBatchCudaSync(Callback): - def __init__(self): - self._done = False - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - if not self._done and torch.cuda.is_available(): - torch.cuda.synchronize() - - def on_after_backward(self, trainer, pl_module): - if not self._done and torch.cuda.is_available(): - torch.cuda.synchronize() - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if not self._done and torch.cuda.is_available(): - torch.cuda.synchronize() - # Unset blocking for subsequent batches - os.environ.pop("CUDA_LAUNCH_BLOCKING", None) - self._done = True - callbacks.append(_FirstBatchCudaSync()) if args.garbage_collect_at_inference: @@ -1103,15 +1082,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): enable_checkpointing=args.create_checkpoint_callback, ) - # Logger setup - nemo_logger.setup( - trainer, - resume_if_exists=True, - ) - - if auto_resume is not None: - auto_resume.setup(trainer, model) - # Optimizer and scheduler setup opt_config = OptimizerConfig( optimizer="adam", @@ -1139,12 +1109,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): opt = MegatronOptimizerModule( opt_config, sched, no_weight_decay_cond=getattr(model_config, "hyena_no_weight_decay_cond_fn", None) ) - opt.connect(model) - - # Remove earlier warmup and hook logic; first-batch blocking is sufficient. + llm.train(model, data_module, trainer, log=nemo_logger, resume=auto_resume, optim=opt, tokenizer="data") - # Start training - trainer.fit(model, data_module) return trainer diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/callbacks.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/callbacks.py index 50ea389bf9..8b83c2230d 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/callbacks.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/callbacks.py @@ -14,11 +14,35 @@ # limitations under the License. import gc +import os import torch from lightning.pytorch import Callback +class _FirstBatchCudaSync(Callback): + # TEMPORARY CALLBACK. Remove once bug is fixed. + # First batch CUDA sync callback: adds barriers for the first training batch to avoid race condition + # See https://github.com/NVIDIA/bionemo-framework/issues/1301 for more details. + def __init__(self): + self._done = False + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + if not self._done and torch.cuda.is_available(): + torch.cuda.synchronize() + + def on_after_backward(self, trainer, pl_module): + if not self._done and torch.cuda.is_available(): + torch.cuda.synchronize() + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self._done and torch.cuda.is_available(): + torch.cuda.synchronize() + # Unset blocking for subsequent batches + os.environ.pop("CUDA_LAUNCH_BLOCKING", None) + self._done = True + + class GarbageCollectAtInferenceTime(Callback): """Callback to clean up CUDA memory before validation to prevent initialization errors."""