non parallelized basic validator implementation [WIP] #1362
base: main
Changes from all commits
29ab77c
ea30e19
a21e119
eb07b37
f27049d
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
from torchtitan.tools import utils
from torchtitan.tools.logging import logger

class BaseValidator:
    def __init__(self, job_config: JobConfig):
        self.job_config = job_config

    def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
        raise NotImplementedError("validate method not implemented")

class Validator(BaseValidator):
    """
    Simple validator focused on correctness and integration.

    Args:
        job_config: Job configuration
        dp_world_size: Data-parallel world size used to shard the validation dataset
        dp_rank: Data-parallel rank of this process
        tokenizer: Tokenizer used to build the validation dataloader
        loss_fn: Loss function to use for validation
    """

    validation_dataloader: BaseDataLoader

    def __init__(
        self,
        job_config: JobConfig,
        dp_world_size: int,
        dp_rank: int,
        tokenizer: Tokenizer,
        loss_fn: LossFunction,
    ):
        self.job_config = job_config
        self.loss_fn = loss_fn
        self.validation_dataloader = build_hf_validation_dataloader(
            job_config=job_config,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            tokenizer=tokenizer,
            infinite=False,
        )
    def should_validate(self, step: int) -> bool:
        return step % self.job_config.validation.val_freq == 0

    def validate(
        self,
        model_parts: list[nn.Module],
    ) -> dict[str, float]:
        # Set model to eval mode
        model = model_parts[0]

Review comments:
- add a …
- Is there a reason to not support all parallelisms besides PP here?
- Does this PR assume the model is not sharded and will handle the model being sharded later? If so, can we raise an exception if dp_shard > 1? If we do support FSDP, then you will need to be careful about …
- great point, I forgot this.

        model.eval()

        total_loss = 0.0
        num_batches = 0
        device_type = utils.device_type
        num_val_steps = 0

        with torch.no_grad():

Review comment: Nit: you can also use this as a decorator instead, so you don't have to indent your code as much:

    @torch.no_grad()
    def validate(

            try:

Review comment: I believe you don't need this try-catch because StopIteration will be automatically captured by the for loop safely.

                for input_dict, labels in self.validation_dataloader:

                    if (
                        self.job_config.validation.val_steps != -1
                        and num_val_steps >= self.job_config.validation.val_steps
                    ):
                        break

                    for k, v in input_dict.items():
                        input_dict[k] = v.to(device_type)
                    labels = labels.to(device_type)

                    inputs = input_dict["input"]
                    predictions = model(inputs)
                    loss = self.loss_fn(predictions, labels)

                    total_loss += loss.item()
                    num_batches += 1

                    num_val_steps += 1

Review comment: Is there a reason you use separate counters for …

    for step, (input_dict, labels) in enumerate(self.validation_dataloader):

Here, …


            except StopIteration:
                logger.info("Validation dataloader exhausted")

        # Compute average loss

Review comment: This assumes that the number of tokens is the same for every batch. Maybe either manually keep track of the total number of tokens or at least add a NOTE that highlights this assumption.
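A minimal sketch of the token-weighted accumulation the reviewer describes, assuming loss_fn returns the mean loss over the tokens in a batch and that labels holds one target per token (both are assumptions, not verified against the actual loss function):

    # inside the loop, instead of counting batches
    num_tokens = labels.numel()              # tokens contributed by this batch
    total_loss += loss.item() * num_tokens   # undo the per-batch mean
    total_tokens += num_tokens

    # after the loop
    average_loss = total_loss / max(total_tokens, 1)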

        if num_batches > 0:
            average_loss = total_loss / num_batches
        else:

Review comment: I think this code path should never be used; you could guarantee this (ignoring the case of an empty dataloader) by adding a …

            average_loss = 0.0
            logger.warning("No validation batches processed")

        logger.info(
            f"Validation completed. Average loss: {average_loss:.4f} over {num_batches} batches"
        )

        # Set model back to train mode
        model.train()

        return {"validation_loss": average_loss}

Review comments:
- The …
- Could you change this to …


def build_validator(
    job_config: JobConfig,
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    loss_fn: LossFunction,
) -> BaseValidator:
    """Build a simple validator focused on correctness."""
    return Validator(
        job_config=job_config,
        dp_world_size=dp_world_size,
        dp_rank=dp_rank,
        tokenizer=tokenizer,
        loss_fn=loss_fn,
    )
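Pulling the smaller review nits together (no_grad as a decorator, enumerate instead of a separate step counter, no try/except around the loop), the validation loop could end up looking roughly like this. This is only a sketch of one possible follow-up, not the final code, and for brevity it moves only the "input" tensor to the device:

    @torch.no_grad()
    def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
        model = model_parts[0]
        model.eval()

        total_loss = 0.0
        num_batches = 0
        val_steps = self.job_config.validation.val_steps

        # enumerate replaces num_val_steps, and the for loop already handles
        # StopIteration, so no explicit try/except is needed.
        for step, (input_dict, labels) in enumerate(self.validation_dataloader):
            if val_steps != -1 and step >= val_steps:
                break
            inputs = input_dict["input"].to(utils.device_type)
            labels = labels.to(utils.device_type)
            loss = self.loss_fn(model(inputs), labels)
            total_loss += loss.item()
            num_batches += 1

        model.train()
        average_loss = total_loss / max(num_batches, 1)
        return {"validation_loss": average_loss}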
@@ -657,6 +657,30 @@ class Experimental:
    """


@dataclass
class Validation:
    enabled: bool = False
    """Enable validation; by default, validation runs after every training step"""

Review comments:
- You could remove this field and modify …
- I think we can keep this …
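If the (truncated) suggestion is to fold the on/off switch into val_freq itself, one possible shape of should_validate under that scheme would be the sketch below; treating a non-positive val_freq as "disabled" is an assumption, not something stated in the review:

    def should_validate(self, step: int) -> bool:
        # a non-positive val_freq would disable validation entirely
        val_freq = self.job_config.validation.val_freq
        return val_freq > 0 and step % val_freq == 0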

    dataset: str = "c4_validation"
    """Dataset to use for validation"""

    dataset_path: str | None = None
    """Path to dataset to use for validation"""

    local_batch_size: int = 8
    """Batch size for validation"""

    seq_len: int = 2048
    """Sequence length for validation"""

Review comment: set up a …

    val_freq: int = 1
    """Frequency of validation"""

Review comments:
- no need to have the … (suggested change)
- maybe default to 10

    val_steps: int = -1
    """Number of validation steps, -1 means all steps"""


@dataclass
class JobConfig:
    """
@@ -681,6 +705,7 @@ class JobConfig:
    memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
    fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
    experimental: Experimental = field(default_factory=Experimental)
    validation: Validation = field(default_factory=Validation)

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)
@@ -49,6 +49,13 @@ class DatasetConfig:
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
    "c4_validation": DatasetConfig(
        path="allenai/c4",
        loader=lambda path: load_dataset(

Review comment: Nit: you can reuse …

            path, name="en", split="validation", streaming=True
        ),
        text_processor=_process_c4_text,
    ),
}

@@ -193,3 +200,34 @@ def build_hf_dataloader(
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )

def build_hf_validation_dataloader(

Review comments:
- I don't think adding a new function for this is necessary; I would prefer replacing the …
- The reason we probably don't want to change this interface is: … I think it's ok we make this compromise for that purpose.

    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,

Review comment: I think we can remove this arg -- I don't think anyone wants to do multiple loops over the validation dataset.

) -> ParallelAwareDataloader:
"""Build a validation data loader for HuggingFace datasets.""" | ||
dataset_name = job_config.validation.dataset | ||
dataset_path = job_config.validation.dataset_path | ||
batch_size = job_config.validation.local_batch_size | ||
seq_len = job_config.validation.seq_len | ||
|
||
    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

Review comment: so you can always set to …

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
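For reference, one shape the "reuse the existing builder" suggestion could take: make the dataset-specific settings explicit arguments so training and validation share a single function. The signature below is a hypothetical sketch that relies on names already defined in this file; it is not the actual torchtitan interface.

    def build_hf_dataloader(
        dp_world_size: int,
        dp_rank: int,
        tokenizer: Tokenizer,
        dataset_name: str,    # e.g. job_config.training.dataset or job_config.validation.dataset
        dataset_path: str | None,
        batch_size: int,
        seq_len: int,
        infinite: bool = True,
    ) -> ParallelAwareDataloader:
        ds = HuggingFaceDataset(
            dataset_name=dataset_name,
            dataset_path=dataset_path,
            tokenizer=tokenizer,
            seq_len=seq_len,
            dp_rank=dp_rank,
            dp_world_size=dp_world_size,
            infinite=infinite,
        )
        return ParallelAwareDataloader(
            dataset=ds,
            dp_rank=dp_rank,
            dp_world_size=dp_world_size,
            batch_size=batch_size,
        )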
@@ -54,6 +54,7 @@ tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
disable_loss_parallel = true

Review comment: revert this change?

[checkpoint]
enable_checkpoint = false
@@ -71,3 +72,9 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas…
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

[validation]
enabled = false
dataset = "c4_validation"
val_freq = 5
val_steps = 10
@@ -52,6 +52,8 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    validator: train_spec_module.BaseValidator | None

    pp_has_first_stage: bool
    pp_has_last_stage: bool

@@ -319,6 +321,18 @@ def __init__(self, job_config: JobConfig):
            device_type,
        )

        # Build validator if validation is configured
        if job_config.validation.enabled:
            assert self.train_spec.build_validator_fn is not None

Review comment: Can you raise an error here if …
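The comment above is cut off, but if the intent is to fail fast when validation is combined with parallelism this WIP validator cannot handle, a guard along these lines would do it (the pp_enabled / dp_shard conditions are assumptions drawn from the earlier review discussion, not from the PR itself):

    if job_config.validation.enabled:
        assert self.train_spec.build_validator_fn is not None
        # This PR's validator uses a single, unsharded model part, so refuse to run
        # with pipeline parallelism (and possibly dp_shard > 1) until supported.
        if parallel_dims.pp_enabled:
            raise NotImplementedError(
                "Validation is not supported with pipeline parallelism yet"
            )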
            self.validator = self.train_spec.build_validator_fn(
                job_config=job_config,
                dp_world_size=dp_degree,
                dp_rank=dp_rank,
                tokenizer=tokenizer,
                loss_fn=self.loss_fn,
            )

        logger.info(
            "Trainer is initialized with "
            f"local batch size {job_config.training.local_batch_size}, "
@@ -463,6 +477,12 @@ def train_step(
        else:
            global_avg_loss = global_max_loss = loss.detach().item()

        # Run validation if validator is available

Review comments:
- as this is not part of training step, let's put this outside …
- As my comment above, please check if validation breaks checkpointing if FSDP is used.

        if self.job_config.validation.enabled and self.validator.should_validate(
            self.step
        ):
            validation_metrics = self.validator.validate(self.model_parts)

Review comments:
- The validation metrics should be logged by …
- For logging to TB/W&B I agree we should use …
- I agree that a separate metrics processor seems like overkill for this implementation. A third option is to also use …

        self.metrics_processor.log(
            self.step,
            global_avg_loss,

Review comment: Technically this is not without parallelism -- you are doing data parallel for validation; however, you are not doing all-reduce on the loss, so the loss you print out would be different on each DP rank. Let's do that in this PR, following the code in the model forward: https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L451-L464. For that you'll need to pass in parallel_dims, world_mesh, and ft_manager when constructing Validator. I think then the code will support Tensor Parallel and Context Parallel but not Pipeline Parallel yet, which we can do in a followup PR.
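A rough sketch of the reduction that comment asks for, assuming the Validator is handed the world_mesh / parallel_dims mentioned above; the mesh handle and names here mirror what the trainer uses and may differ in the actual code:

    import torch
    import torch.distributed as dist

    def dist_mean_loss(total_loss: float, num_batches: int, dp_cp_mesh, device_type: str) -> float:
        # Sum both the accumulated loss and the batch count across DP ranks, then
        # divide, so every rank reports the same global average validation loss.
        packed = torch.tensor([total_loss, float(num_batches)], device=device_type)
        dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=dp_cp_mesh.get_group())
        return (packed[0] / packed[1].clamp(min=1)).item()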