diff --git a/msamp/nn/distributed.py b/msamp/nn/distributed.py
index cce3bb1e..e88c93e9 100644
--- a/msamp/nn/distributed.py
+++ b/msamp/nn/distributed.py
@@ -11,7 +11,6 @@
 from msamp.common.tensor import ScalingTensor, ScalingMeta
 from msamp.common.dtype import Dtypes, Floating
 from msamp.common.utils import TransformerEngineWrapper
-from msamp.nn.state import model_state
 from msamp.operators.dist_op import DistOp
 
 
@@ -246,7 +245,6 @@ def __init__(self, module, **kwargs):
         self.scaling_tensor_reducer = _ScalingTensorReducer(
             scaling_params, self.process_group, self.bucket_bytes_cap
         )
-        model_state.use_fp8_ddp = True
 
     def forward(self, *inputs, **kwargs):
         """Apply _DDPSink in forward function.
@@ -255,7 +253,7 @@ def forward(self, *inputs, **kwargs):
             inputs (tuple): The input tensors.
             kwargs (dict): The keyword arguments.
         """
-        if model_state.use_fp8_ddp and torch.is_grad_enabled():
+        if torch.is_grad_enabled():
             inputs = _DDPSink.apply(self.scaling_tensor_reducer, torch.tensor([], requires_grad=True), *inputs)
         out = super().forward(*inputs, **kwargs)
         return out
diff --git a/msamp/nn/functional.py b/msamp/nn/functional.py
index 23d023cb..726fc5e2 100644
--- a/msamp/nn/functional.py
+++ b/msamp/nn/functional.py
@@ -109,6 +109,7 @@ def backward(ctx, output_grad):
                 use_split_accumulator=True,
             )
             del old_wgrad
+
             if hasattr(ctx, 'return_wgrad') and ctx.return_wgrad:
                 wgrad = wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True)
                 wgrad = wgrad.value.view(-1).view(dtype=torch.float32)
diff --git a/msamp/nn/linear.py b/msamp/nn/linear.py
index c33f3fa8..c5a6ca63 100644
--- a/msamp/nn/linear.py
+++ b/msamp/nn/linear.py
@@ -183,8 +183,8 @@ def replace(cls, model, weight_qtype=Dtypes.kfloat16, src_rank=0, group=None):
         for k, p in fp8_named_weights:
             p._param_name = k
 
-        # DDP ignores the FP8 weights, and the optimizer provides a function `optimizer.all_reduce_grads(model)`
-        # to sync them.
+        # The native DDP ignores the FP8 weights,
+        # and msamp.nn.distributed.FP8DistributedDataParallel will handle them.
         fp8_names = []
         for module_name, module in model.named_modules():
             for param_name, param in module.named_parameters(recurse=False):
diff --git a/msamp/nn/state.py b/msamp/nn/state.py
index 54edf5d2..6f382428 100644
--- a/msamp/nn/state.py
+++ b/msamp/nn/state.py
@@ -21,7 +21,6 @@ def __init__(self):
         # OrderedDict[str, dict[str, ScalingMeta]], store the local scaling metas in all FP8Linear modules.
         # key is module name, value is scaling_metas in FP8Linear module.
         self._local_scaling_metas = OrderedDict()
-        self._use_fp8_ddp = False
 
     @property
     def ready_to_scale_tensor(self):
@@ -42,20 +41,6 @@ def flattened_scaling_metas(self):
         """Decoration function to access _flattened_scaling_metas variable."""
         return self._flattened_scaling_metas
 
-    @property
-    def use_fp8_ddp(self):
-        """Decoration function to access _use_fp8_ddp variable."""
-        return self._use_fp8_ddp
-
-    @use_fp8_ddp.setter
-    def use_fp8_ddp(self, value):
-        """Set the value of _use_fp8_ddp variable.
-
-        Args:
-            value (bool): Value to set.
-        """
-        self._use_fp8_ddp = value
-
     @flattened_scaling_metas.setter
     def flattened_scaling_metas(self, value):
         """Set the value of _flattened_scaling_metas variable.
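Taken together, the msamp/nn changes above drop the global `model_state.use_fp8_ddp` switch: `FP8DistributedDataParallel` now always routes FP8 weight gradients through `_DDPSink` and `_ScalingTensorReducer` whenever gradients are enabled. A minimal usage sketch of the new flow, assuming the wrapper keeps the plain `module`-first constructor shown above; the process-group setup is illustrative and not part of this patch:

```python
import torch
import torch.distributed as dist

from msamp.common.dtype import Dtypes
from msamp.nn import LinearReplacer
from msamp.nn.distributed import FP8DistributedDataParallel
from msamp.optim import LBAdamW

# Illustrative launch assumptions: started via torchrun with NCCL available.
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(4, 4).cuda()
model = LinearReplacer.replace(model, Dtypes.kfloat16)  # FP8Linear with ScalingParameter weights
opt = LBAdamW(model.parameters())                       # low-bit optimizer over the scaling params
model = FP8DistributedDataParallel(model)               # bucket-reduces the FP8 weight grads itself

x = torch.randn(4, 4, device='cuda')
model(x).sum().backward()  # _DDPSink triggers the scaling-tensor all-reduce during backward
opt.step()                 # no opt.all_reduce_grads(model) call is needed any more
opt.zero_grad()
```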
diff --git a/msamp/optim/optimizer.py b/msamp/optim/optimizer.py
index 640461fd..2fbe449f 100644
--- a/msamp/optim/optimizer.py
+++ b/msamp/optim/optimizer.py
@@ -13,8 +13,7 @@
 
 from msamp.common.dtype import Floating
 from msamp.common.tensor import ScalingTensor, ScalingMeta
-from msamp.common.tensor import TensorDist
-from msamp.nn import model_state, ScalingParameter
+from msamp.nn import model_state
 
 
 class LBOptimizer(Optimizer):
@@ -42,14 +41,6 @@ def step(self, closure=None):
         self._update_scaling_factors()
         return rtn
 
-    def all_reduce_grads(self, model):
-        """All-reduce gradients of parameters."""
-        if model_state.use_fp8_ddp:
-            return
-        scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)]
-        grads = [p.grad for p in scaling_params if p.grad is not None]
-        TensorDist.all_reduce_avg(grads)
-
     def lb_step(self, closure=None):
         """Performs a single optimization step. The subclass needs to implement this method.
 
diff --git a/tests/nn/test_linear.py b/tests/nn/test_linear.py
index 6fcc9805..9903fdb0 100644
--- a/tests/nn/test_linear.py
+++ b/tests/nn/test_linear.py
@@ -56,7 +56,7 @@ def test_fp8linear_backward(self):
         self.assertTrue(torch.equal(fp8linear.bias.grad, linear.bias.grad))
 
         # check weight.
-        self.assertTrue(isinstance(fp8linear.weight.grad, ScalingTensor))
+        self.assertTrue(isinstance(fp8linear.weight.grad, torch.Tensor))
         self.assertTrue(fp8linear.weight.grad.size() == linear.weight.grad.size())
 
     @decorator.cuda_test
diff --git a/tests/optim/test_adamw.py b/tests/optim/test_adamw.py
index f38ad360..8f356f47 100644
--- a/tests/optim/test_adamw.py
+++ b/tests/optim/test_adamw.py
@@ -11,7 +11,7 @@
 from functools import partial
 
 from msamp.common.dtype import Dtypes
-from msamp.common.tensor import TensorDist, ScalingTensor
+from msamp.common.tensor import ScalingTensor
 from msamp.optim import LBAdamW, LBAdam, LBAdamWBase, DSAdam
 from msamp.nn import LinearReplacer
 from tests.helper import decorator
@@ -78,37 +78,11 @@ def check_optimizer_step(self, optimizer_class1, optimizer_class2, diff=3e-4):
         for _ in range(steps):
             output = model2(input)
             output.sum().backward()
-            opt2.all_reduce_grads(model2)
             opt2.step()
             opt2.zero_grad()
 
         self.assertTrue(torch.allclose(model1.weight, model2.weight.float(), 0, diff))
 
-    def test_all_reduce_grads(self):
-        """Test the function `all_reduce_grads`."""
-        input = torch.randn(4, 4, device='cuda')
-        model1 = torch.nn.Linear(4, 4).cuda()
-        model2 = torch.nn.Linear(4, 4).cuda()
-        model1 = LinearReplacer.replace(model1, Dtypes.kfloat16)
-        model2 = LinearReplacer.replace(model2, Dtypes.kfloat16)
-        opt = LBAdamW(list(model1.parameters()) + list(model2.parameters()))
-        loss = (model1(input) + model2(input)).sum()
-        loss.backward()
-        old_all_reduce_avg = TensorDist.all_reduce_avg
-        num_grads = 0
-
-        def debug_all_reduce_avg(grads):
-            nonlocal num_grads
-            num_grads += len(grads)
-            return old_all_reduce_avg(grads)
-
-        TensorDist.all_reduce_avg = debug_all_reduce_avg
-        opt.all_reduce_grads(model1)
-        self.assertEqual(num_grads, 1)
-        opt.all_reduce_grads(model2)
-        self.assertEqual(num_grads, 2)
-        TensorDist.all_reduce_avg = old_all_reduce_avg
-
     def check_optimizer_state_dict(self, lbadam_class):
         """Save and load state dict of lbadam_class optimizer and check if the value is excepted.
 
@@ -127,7 +101,6 @@ def check_optimizer_state_dict(self, lbadam_class):
             output = model1(input)
             opt1.zero_grad()
             output.sum().backward()
-            opt1.all_reduce_grads(model1)
             opt1.step()
 
         state_dict1 = opt1.state_dict()
@@ -158,7 +131,6 @@ def check_optimizer_state_dict(self, lbadam_class):
         state_dict2 = copy.deepcopy(state_dict1)
         opt1.zero_grad()
         model1(input).sum().backward()
-        opt1.all_reduce_grads(model1)
         opt1.step()
 
         # Build model2 and update 4 times.
@@ -171,7 +143,6 @@ def check_optimizer_state_dict(self, lbadam_class):
             output = model2(input)
             opt2.zero_grad()
             output.sum().backward()
-            opt2.all_reduce_grads(model2)
             opt2.step()
 
         # Load state dict to op2 and check if the weight is same as model1 after update weigth once.
@@ -180,7 +151,6 @@ def check_optimizer_state_dict(self, lbadam_class):
 
         opt2.zero_grad()
         model2(input).sum().backward()
-        opt2.all_reduce_grads(model2)
         opt2.step()
 
         self.assertTrue(torch.equal(model1.weight.value, model2.weight.value))
@@ -216,5 +186,4 @@ def test_historical_window_quantization(self):
             y = model(x)
             self.assertTrue((model.scaling_metas['input'].amax.max() == max(windows)).all())
             y.sum().backward()
-            opt.all_reduce_grads(model)
             opt.step()
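As the tests/nn/test_linear.py hunk above reflects, outside of `FP8DistributedDataParallel` the weight gradient produced by FP8Linear is now a plain `torch.Tensor` rather than a `ScalingTensor`; the fp8-cast-and-flatten path in functional.py only runs when `ctx.return_wgrad` is set, presumably by the DDP reducer path. A small single-GPU check of that behaviour, reusing the helpers the tests already import; the final shape assertion mirrors the existing test and is an assumption about the single-GPU path, not something this patch states directly:

```python
import torch

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingTensor
from msamp.nn import LinearReplacer

# Single-GPU sanity check: no process group, no FP8DistributedDataParallel involved.
linear = torch.nn.Linear(4, 4).cuda()
fp8linear = LinearReplacer.replace(linear, Dtypes.kfloat16)

x = torch.randn(4, 4, device='cuda')
fp8linear(x).sum().backward()

# After this change the weight grad is an ordinary tensor, not a ScalingTensor,
# and it is expected to keep the weight's original shape in this path.
assert isinstance(fp8linear.weight.grad, torch.Tensor)
assert not isinstance(fp8linear.weight.grad, ScalingTensor)
assert fp8linear.weight.grad.size() == torch.Size([4, 4])
```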