Skip to content

non parallelized basic validator implementation [WIP] #1362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,17 @@ def build_test_list():
"gradient_accumulation",
ngpu=2,
),
OverrideDefinitions(
[
[
"--validation.enabled",
"--validation.dataset c4_test",
],
],
"Validation test no parallelism",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically this is not without parallelism -- you are doing data parallel for validation; however, you are not doing all-reduce on the loss, so the loss you print out would be different on each DP rank. Let's do that in this PR, following the code in model forward.
https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L451-L464

For that you'll need to pass in parallel_dims world_mesh ft_manager when constructing Validator

I think then the code will support Tensor Parallel and Context Parallel but not Pipeline Parallel yet, which we can do in a followup PR.

"validation_no_parallel",
ngpu=1,
),
]
return integration_tests_flavors

Expand Down
131 changes: 131 additions & 0 deletions torchtitan/components/validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
from torchtitan.tools import utils
from torchtitan.tools.logging import logger


class BaseValidator:
def __init__(self, job_config: JobConfig):
self.job_config = job_config

def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
raise NotImplementedError("validate method not implemented")


class Validator(BaseValidator):
"""
Simple validator focused on correctness and integration.

Args:
job_config: Job configuration
validation_dataloader: The validation dataloader
loss_fn: Loss function to use for validation
model: The model to validate (single model, no parallelism)
"""

validation_dataloader: BaseDataLoader

def __init__(
self,
job_config: JobConfig,
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
loss_fn: LossFunction,
):
self.job_config = job_config
self.loss_fn = loss_fn
self.validation_dataloader = build_hf_validation_dataloader(
job_config=job_config,
dp_world_size=dp_world_size,
dp_rank=dp_rank,
tokenizer=tokenizer,
infinite=False,
)

def should_validate(self, step: int) -> bool:
return step % self.job_config.validation.val_freq == 0

def validate(
self,
model_parts: list[nn.Module],
) -> dict[str, float]:
# Set model to eval mode
model = model_parts[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a TODO: here claiming we only support data parallel for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to not support all parallelisms besides PP here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this PR assume the model is not sharded and will handle model being sharded later? If so can we raise an exception if dp_shard > 1? If we do support FSDP, then you will need to be careful about reshard_after_forward value as ensure the parameters are sharded before leaving validate(). Otherwise, checkpointing will be broken.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great point, I forgot this.
Please include & adapt the following code in Validator.validate()
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L185-L189
cc @wwwjn

model.eval()

total_loss = 0.0
num_batches = 0
device_type = utils.device_type
num_val_steps = 0

with torch.no_grad():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: you can also use this as a decorator instead, so you don't have to indent your code as much.

@torch.no_grad()
def validate(

try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you don't need this try-catch because StopIteration will be automatically captured by for loop safely.

for input_dict, labels in self.validation_dataloader:

if (
self.job_config.validation.val_steps != -1
and num_val_steps >= self.job_config.validation.val_steps
):
break

for k, v in input_dict.items():
input_dict[k] = v.to(device_type)
labels = labels.to(device_type)

inputs = input_dict["input"]
predictions = model(inputs)
loss = self.loss_fn(predictions, labels)

total_loss += loss.item()
num_batches += 1

num_val_steps += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you use separate counters for num_batches and num_val_steps? Also, you could use this instead:

for step, (input_dict, labels) in enumerate(self.validation_dataloader):

Here, step replaces num_batches and num_val_steps. You would also have to change num_val_steps >= self.job_config.validation.val_steps to step > self.job_config.validation.val_steps above.


except StopIteration:
logger.info("Validation dataloader exhausted")

# Compute average loss
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes that the number of tokens is the same for every batch. Maybe either manually keep track of the total number of tokens or at least add a NOTE that highlights this assumption.

if num_batches > 0:
average_loss = total_loss / num_batches
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this code path should never be used, you could guarantee this (ignoring the case of an empty dataloader) by adding a __post_init__ to the Validation dataclass that verifies that all values are valid, e.g., val_steps > 0.

average_loss = 0.0
logger.warning("No validation batches processed")

logger.info(
f"Validation completed. Average loss: {average_loss:.4f} over {num_batches} batches"
)

# Set model back to train mode
model.train()

return {"validation_loss": average_loss}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The average_loss is the local loss for each rank, but should still be all-reduced across ranks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change this to "validation/loss"? This is important for how wandb represents the metrics and allows you to add more metrics to the same section via "validation/<you-new-metric>" later on.



def build_validator(
job_config: JobConfig,
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
loss_fn: LossFunction,
) -> BaseValidator:
"""Build a simple validator focused on correctness."""
return Validator(
job_config=job_config,
dp_world_size=dp_world_size,
dp_rank=dp_rank,
tokenizer=tokenizer,
loss_fn=loss_fn,
)
25 changes: 25 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,30 @@ class Experimental:
"""


@dataclass
class Validation:
enabled: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could remove this field and modify val_freq to offer an option for disabling validation, e.g., val_freq: int | None = 10, where validation is disabled if val_freq=None.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can keep this enabled to be consistent with other configs in torchtitan -- this sounds more like a style thing?

"""Enable validation to default run validation after each training loop"""

dataset: str = "c4_validation"
"""Dataset to use for validation"""

dataset_path: str | None = None
"""Path to dataset to use for validation"""

local_batch_size: int = 8
"""Batch size for validation"""

seq_len: int = 2048
"""Sequence length for validation"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set up a steps config, controlling how many iterations we run, default to -1 which means consuming all the data in the validation dataset

val_freq: int = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to have the val_ prefix as it's not ambiguous under Validation

Suggested change
val_freq: int = 1
freq: int = 1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe default to 10

"""Frequency of validation"""

val_steps: int = -1
"""Number of validation steps, -1 means all steps"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Number of validation steps, -1 means all steps"""
"""Number of validation steps, -1 means consuming all the data in the validation dataset"""



@dataclass
class JobConfig:
"""
Expand All @@ -681,6 +705,7 @@ class JobConfig:
memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
experimental: Experimental = field(default_factory=Experimental)
validation: Validation = field(default_factory=Validation)

def to_dict(self) -> dict[str, Any]:
return asdict(self)
Expand Down
38 changes: 38 additions & 0 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class DatasetConfig:
loader=lambda path: load_dataset(path, split="train"),
text_processor=_process_c4_text,
),
"c4_validation": DatasetConfig(
path="allenai/c4",
loader=lambda path: load_dataset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: you can reuse _load_c4_dataset together with functools.partial here by adding split as an argument to _load_c4_dataset.

path, name="en", split="validation", streaming=True
),
text_processor=_process_c4_text,
),
}


Expand Down Expand Up @@ -193,3 +200,34 @@ def build_hf_dataloader(
dp_world_size=dp_world_size,
batch_size=batch_size,
)


def build_hf_validation_dataloader(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think adding a new function for this is necessary; I would prefer replacing the job_config argument with dataset_name, dataset_path, batch_size, and seq_len. The reasoning is that for validation the function is also just returning a data loader based on a HF dataset, just the underlying dataset will be different.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we probably don't want to change this interface is:
people plug in their own data loader, and they want it to be general enough https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L133-L138
We used to have something closer to what you proposed, but changed due to their requests.

I think it's ok we make this compromise for that purpose.

dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
job_config: JobConfig,
infinite: bool = True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove this arg -- I don't think anyone wants to do multiple loops over the validation dataset

) -> ParallelAwareDataloader:
"""Build a validation data loader for HuggingFace datasets."""
dataset_name = job_config.validation.dataset
dataset_path = job_config.validation.dataset_path
batch_size = job_config.validation.local_batch_size
seq_len = job_config.validation.seq_len

hf_ds = HuggingFaceDataset(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you can always set to False here

)

return ParallelAwareDataloader(
dataset=hf_ds,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
2 changes: 2 additions & 0 deletions torchtitan/models/llama3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
Expand Down Expand Up @@ -81,5 +82,6 @@
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_tiktoken_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
)
)
7 changes: 7 additions & 0 deletions torchtitan/models/llama3/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
disable_loss_parallel = true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this change?


[checkpoint]
enable_checkpoint = false
Expand All @@ -71,3 +72,9 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

[validation]
enabled = false
dataset = "c4_validation"
val_freq = 5
val_steps = 10
3 changes: 3 additions & 0 deletions torchtitan/protocols/train_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchtitan.components.metrics import MetricsProcessor
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.components.validate import BaseValidator
from torchtitan.config_manager import JobConfig

DeviceType = int | str | torch.device
Expand Down Expand Up @@ -77,6 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
[OptimizersContainer, JobConfig], LRSchedulersContainer
]
LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
ValidatorBuilder: TypeAlias = Callable[..., BaseValidator]


@dataclass
Expand All @@ -91,6 +93,7 @@ class TrainSpec:
build_dataloader_fn: DataLoaderBuilder
build_tokenizer_fn: TokenizerBuilder | None
build_loss_fn: LossFunctionBuilder
build_validator_fn: ValidatorBuilder | None = None
build_metrics_processor_fn: MetricsProcessorBuilder | None = None


Expand Down
20 changes: 20 additions & 0 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
optimizers: train_spec_module.OptimizersContainer
lr_schedulers: train_spec_module.LRSchedulersContainer

validator: train_spec_module.BaseValidator | None

pp_has_first_stage: bool
pp_has_last_stage: bool

Expand Down Expand Up @@ -319,6 +321,18 @@ def __init__(self, job_config: JobConfig):
device_type,
)

# Build validator if validation is configured
if job_config.validation.enabled:
assert self.train_spec.build_validator_fn is not None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you raise an error here if parallel_dims.pp_enabled?

self.validator = self.train_spec.build_validator_fn(
job_config=job_config,
dp_world_size=dp_degree,
dp_rank=dp_rank,
tokenizer=tokenizer,
loss_fn=self.loss_fn,
)

logger.info(
"Trainer is initialized with "
f"local batch size {job_config.training.local_batch_size}, "
Expand Down Expand Up @@ -463,6 +477,12 @@ def train_step(
else:
global_avg_loss = global_max_loss = loss.detach().item()

# Run validation if validator is available
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as this is not part of training step, let's put this outside train_step and put it in train before self.checkpointer.save(...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As my comment above, please check if validation breaks checkpointing if FSDP is used.

if self.job_config.validation.enabled and self.validator.should_validate(
self.step
):
validation_metrics = self.validator.validate(self.model_parts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation metrics should be logged by self.metrics_processor.log() (to the terminal output and Tensorboard/wandb).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For logging to TB/W&B I agree we should use self.metrics_processor.logger.
For terminal, there are two options, one is to do it locally in this function, the other is creating a new metrics processor like what you did. I personally think the latter tries to make the style consistent (which I appreciate), but sounds a bit overkill.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that a separate metrics processor seems like overkill for this implementation. A third option is to also use self.metrics_processor.log to print the metrics in the terminal.


self.metrics_processor.log(
self.step,
global_avg_loss,
Expand Down
Loading