diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 3ccbc1890..dfe1d86f9 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -509,6 +509,17 @@ def build_test_list():
             "gradient_accumulation",
             ngpu=2,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--validation.enabled",
+                    "--validation.dataset c4_test",
+                ],
+            ],
+            "Validation test no parallelism",
+            "validation_no_parallel",
+            ngpu=1,
+        ),
     ]
 
     return integration_tests_flavors
diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py
new file mode 100644
index 000000000..fb1d5cabb
--- /dev/null
+++ b/torchtitan/components/validate.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from torchtitan.components.dataloader import BaseDataLoader
+from torchtitan.components.loss import LossFunction
+from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.config_manager import JobConfig
+from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
+from torchtitan.tools import utils
+from torchtitan.tools.logging import logger
+
+
+class BaseValidator:
+    def __init__(self, job_config: JobConfig):
+        self.job_config = job_config
+
+    def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
+        raise NotImplementedError("validate method not implemented")
+
+
+class Validator(BaseValidator):
+    """
+    Simple validator focused on correctness and integration.
+
+    Args:
+        job_config: Job configuration
+        dp_world_size, dp_rank: Data parallel world size and rank
+        tokenizer: Tokenizer used to build the validation dataloader
+        loss_fn: Loss function to use for validation
+    """
+
+    validation_dataloader: BaseDataLoader
+
+    def __init__(
+        self,
+        job_config: JobConfig,
+        dp_world_size: int,
+        dp_rank: int,
+        tokenizer: Tokenizer,
+        loss_fn: LossFunction,
+    ):
+        self.job_config = job_config
+        self.loss_fn = loss_fn
+        self.validation_dataloader = build_hf_validation_dataloader(
+            job_config=job_config,
+            dp_world_size=dp_world_size,
+            dp_rank=dp_rank,
+            tokenizer=tokenizer,
+            infinite=False,
+        )
+
+    def should_validate(self, step: int) -> bool:
+        return step % self.job_config.validation.val_freq == 0
+
+    def validate(
+        self,
+        model_parts: list[nn.Module],
+    ) -> dict[str, float]:
+        # Set model to eval mode
+        model = model_parts[0]
+        model.eval()
+
+        total_loss = 0.0
+        num_batches = 0
+        device_type = utils.device_type
+        num_val_steps = 0
+
+        with torch.no_grad():
+            try:
+                for input_dict, labels in self.validation_dataloader:
+
+                    if (
+                        self.job_config.validation.val_steps != -1
+                        and num_val_steps >= self.job_config.validation.val_steps
+                    ):
+                        break
+
+                    for k, v in input_dict.items():
+                        input_dict[k] = v.to(device_type)
+                    labels = labels.to(device_type)
+
+                    inputs = input_dict["input"]
+                    predictions = model(inputs)
+                    loss = self.loss_fn(predictions, labels)
+
+                    total_loss += loss.item()
+                    num_batches += 1
+
+                    num_val_steps += 1
+
+            except StopIteration:
+                logger.info("Validation dataloader exhausted")
+
+        # Compute average loss
+        if num_batches > 0:
+            average_loss = total_loss / num_batches
+        else:
+            average_loss = 0.0
+            logger.warning("No validation batches processed")
+
+        logger.info(
+            f"Validation completed. Average loss: {average_loss:.4f} over {num_batches} batches"
+        )
+
+        # Set model back to train mode
+        model.train()
+
+        return {"validation_loss": average_loss}
+
+
+def build_validator(
+    job_config: JobConfig,
+    dp_world_size: int,
+    dp_rank: int,
+    tokenizer: Tokenizer,
+    loss_fn: LossFunction,
+) -> BaseValidator:
+    """Build a simple validator focused on correctness."""
+    return Validator(
+        job_config=job_config,
+        dp_world_size=dp_world_size,
+        dp_rank=dp_rank,
+        tokenizer=tokenizer,
+        loss_fn=loss_fn,
+    )
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 3f8d25688..5caa15fbf 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -657,6 +657,30 @@ class Experimental:
     """
 
 
+@dataclass
+class Validation:
+    enabled: bool = False
+    """Enable validation; when enabled, validation runs every val_freq training steps"""
+
+    dataset: str = "c4_validation"
+    """Dataset to use for validation"""
+
+    dataset_path: str | None = None
+    """Path to dataset to use for validation"""
+
+    local_batch_size: int = 8
+    """Batch size for validation"""
+
+    seq_len: int = 2048
+    """Sequence length for validation"""
+
+    val_freq: int = 1
+    """Frequency of validation, in training steps"""
+
+    val_steps: int = -1
+    """Number of validation steps, -1 means all steps"""
+
+
 @dataclass
 class JobConfig:
     """
@@ -681,6 +705,7 @@ class JobConfig:
     memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
     fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
    experimental: Experimental = field(default_factory=Experimental)
+    validation: Validation = field(default_factory=Validation)
 
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py
index 023b4a29e..818f66104 100644
--- a/torchtitan/datasets/hf_datasets.py
+++ b/torchtitan/datasets/hf_datasets.py
@@ -49,6 +49,13 @@ class DatasetConfig:
         loader=lambda path: load_dataset(path, split="train"),
         text_processor=_process_c4_text,
     ),
+    "c4_validation": DatasetConfig(
+        path="allenai/c4",
+        loader=lambda path: load_dataset(
+            path, name="en", split="validation", streaming=True
+        ),
+        text_processor=_process_c4_text,
+    ),
 }
 
 
@@ -193,3 +200,34 @@ def build_hf_dataloader(
         dp_world_size=dp_world_size,
         batch_size=batch_size,
     )
+
+
+def build_hf_validation_dataloader(
+    dp_world_size: int,
+    dp_rank: int,
+    tokenizer: Tokenizer,
+    job_config: JobConfig,
+    infinite: bool = True,
+) -> ParallelAwareDataloader:
+    """Build a validation data loader for HuggingFace datasets."""
+    dataset_name = job_config.validation.dataset
+    dataset_path = job_config.validation.dataset_path
+    batch_size = job_config.validation.local_batch_size
+    seq_len = job_config.validation.seq_len
+
+    hf_ds = HuggingFaceDataset(
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        seq_len=seq_len,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        infinite=infinite,
+    )
+
+    return ParallelAwareDataloader(
+        dataset=hf_ds,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        batch_size=batch_size,
+    )
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index 7cfc03b4a..eec35cbf1 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -9,6 +9,7 @@
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
+from torchtitan.components.validate import build_validator
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
@@ -81,5 +82,6 @@
         build_dataloader_fn=build_hf_dataloader,
         build_tokenizer_fn=build_tiktoken_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
+        build_validator_fn=build_validator,
     )
 )
diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml
index 52a362e1a..7cefaca39 100644
--- a/torchtitan/models/llama3/train_configs/debug_model.toml
+++ b/torchtitan/models/llama3/train_configs/debug_model.toml
@@ -54,6 +54,7 @@ tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 context_parallel_degree = 1
+disable_loss_parallel = true
 
 [checkpoint]
 enable_checkpoint = false
@@ -71,3 +72,9 @@ selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output"]
+
+[validation]
+enabled = false
+dataset = "c4_validation"
+val_freq = 5
+val_steps = 10
diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py
index ddb6961e5..98158bca9 100644
--- a/torchtitan/protocols/train_spec.py
+++ b/torchtitan/protocols/train_spec.py
@@ -22,6 +22,7 @@
 from torchtitan.components.metrics import MetricsProcessor
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.components.validate import BaseValidator
 from torchtitan.config_manager import JobConfig
 
 DeviceType = int | str | torch.device
@@ -77,6 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
     [OptimizersContainer, JobConfig], LRSchedulersContainer
 ]
 LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
+ValidatorBuilder: TypeAlias = Callable[..., BaseValidator]
 
 
 @dataclass
@@ -91,6 +93,7 @@ class TrainSpec:
     build_dataloader_fn: DataLoaderBuilder
     build_tokenizer_fn: TokenizerBuilder | None
     build_loss_fn: LossFunctionBuilder
+    build_validator_fn: ValidatorBuilder | None = None
     build_metrics_processor_fn: MetricsProcessorBuilder | None = None
diff --git a/torchtitan/train.py b/torchtitan/train.py
index ca1480f2e..7f8c723a9 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -52,6 +52,8 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     optimizers: train_spec_module.OptimizersContainer
     lr_schedulers: train_spec_module.LRSchedulersContainer
 
+    validator: train_spec_module.BaseValidator | None
+
     pp_has_first_stage: bool
     pp_has_last_stage: bool
 
@@ -319,6 +321,18 @@ def __init__(self, job_config: JobConfig):
             device_type,
         )
 
+        # Build validator if validation is configured
+        if job_config.validation.enabled:
+            assert self.train_spec.build_validator_fn is not None
+
+            self.validator = self.train_spec.build_validator_fn(
+                job_config=job_config,
+                dp_world_size=dp_degree,
+                dp_rank=dp_rank,
+                tokenizer=tokenizer,
+                loss_fn=self.loss_fn,
+            )
+
         logger.info(
             "Trainer is initialized with "
             f"local batch size {job_config.training.local_batch_size}, "
@@ -463,6 +477,12 @@ def train_step(
         else:
            global_avg_loss = global_max_loss = loss.detach().item()
 
+        # Run validation if validator is available
+        if self.job_config.validation.enabled and self.validator.should_validate(
+            self.step
+        ):
+            validation_metrics = self.validator.validate(self.model_parts)
+
         self.metrics_processor.log(
             self.step,
             global_avg_loss,
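Note: the [validation] block added to debug_model.toml above ships with enabled = false, so the new code path stays off by default. A minimal sketch of a config override that would turn it on, using only fields defined by the new Validation dataclass (values mirror the debug config and are illustrative, not authoritative defaults):

[validation]
enabled = true
dataset = "c4_validation"
val_freq = 5    # validate every 5 training steps
val_steps = 10  # stop after 10 validation batches; -1 means all available batches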