diff --git a/tester/accuracy.py b/tester/accuracy.py index ccac882d..0194cf05 100644 --- a/tester/accuracy.py +++ b/tester/accuracy.py @@ -72,7 +72,10 @@ def test(self): return try: - device = torch.device("cuda:0") + if paddle.device.get_device() == "cpu": + device = torch.device("cpu") + else: + device = torch.device("cuda:0") torch.set_default_device(device) if not self.gen_torch_input(): print("gen_torch_input failed", flush=True) diff --git a/tester/paddle_to_torch/mapping.json b/tester/paddle_to_torch/mapping.json index 6a915931..8420563b 100644 --- a/tester/paddle_to_torch/mapping.json +++ b/tester/paddle_to_torch/mapping.json @@ -2739,8 +2739,8 @@ "weight": "None", "reduction": "'mean'", "soft_label": "False", - "label_smoothing": "0.0", - "reduction_original": "None" + "use_softmax": "True", + "label_smoothing": "0.0" }, "paddle_torch_args_map": { "input": "input", @@ -4906,4 +4906,4 @@ "paddle.Tensor.zero_": { "torch_api": "torch.Tensor.zero_" } -} \ No newline at end of file +} diff --git a/tester/paddle_to_torch/rules.py b/tester/paddle_to_torch/rules.py index ac4bde45..5961f297 100644 --- a/tester/paddle_to_torch/rules.py +++ b/tester/paddle_to_torch/rules.py @@ -696,42 +696,84 @@ def apply(self, paddle_api: str) -> ConvertResult: defaults_code, map_code = self.apply_generic() pre = """ shp = label.shape -if len(input.shape) > 2: - perm = [0] + [len(input.shape)-1]+ [i for i in range(1,len(input.shape)-1)] +axis = locals().get('axis', -1) +use_softmax = locals().get('use_softmax', True) +if axis < 0: + axis += input.dim() +_manual_soft_label_ce = soft_label and use_softmax and axis == input.dim() - 1 +if not use_softmax: + input = torch.nn.functional.softmax(input, dim=axis) +if len(input.shape) > 2 and not _manual_soft_label_ce: + perm = [0] + [len(input.shape)-1] + [i for i in range(1, len(input.shape)-1)] input = input.permute(*perm) -axis = locals().get('axis',-1) label = label.squeeze(-1) if weight is not None: weight.requires_grad = False if label.dtype == torch.int32: label = label.long() -if soft_label and weight is not None and shp == input.shape: - reduction_original = reduction - weight_original = weight +_manual_weight = soft_label and weight is not None and torch.is_floating_point(label) +if _manual_soft_label_ce: + _original_label = label + _did_onehot_from_int = False + if label_smoothing > 0.0 and not torch.is_floating_point(label): + _original_label = torch.nn.functional.one_hot(label.long(), num_classes=input.shape[-1]).float() + label = _original_label + _did_onehot_from_int = True + if torch.is_floating_point(label) and not _did_onehot_from_int: + label = label.to(dtype=input.dtype) + _original_label = _original_label.to(dtype=input.dtype) + if label_smoothing > 0.0: + label = label * (1.0 - label_smoothing) + label_smoothing / input.shape[-1] + if _did_onehot_from_int: + label = label.to(dtype=input.dtype) + _original_label = _original_label.to(dtype=input.dtype) + _weighted_label = label + _manual_weight = weight is not None +if _manual_weight: + _saved_reduction = reduction + _saved_weight = weight reduction = "none" weight = None """ core = f""" -result = {self.torch_api}(**_kwargs) +if not use_softmax and not soft_label and label_smoothing == 0.0: + result = torch.nn.functional.nll_loss( + input=torch.log(_kwargs["input"]), + target=_kwargs["target"], + weight=_kwargs.get("weight"), + ignore_index=_kwargs.get("ignore_index", -100), + reduction=_kwargs.get("reduction", "mean"), + ) +elif _manual_soft_label_ce: + _log_prob = torch.nn.functional.log_softmax(input, dim=axis) + result = -(label * _log_prob).sum(dim=axis) +else: + if not use_softmax: + _kwargs["input"] = torch.log(_kwargs["input"]) + result = {self.torch_api}(**_kwargs) """ post = """ -if reduction_original is not None: - reduction = reduction_original - loss_weight = label@weight_original - sum_weight = loss_weight.sum() - result *= loss_weight -else: - sum_weight = result.numel() - +if _manual_weight: + if _manual_soft_label_ce: + loss_weight = (_weighted_label.to(dtype=_saved_weight.dtype) * _saved_weight).sum(dim=axis) + else: + loss_weight = label.to(dtype=_saved_weight.dtype) @ _saved_weight + result = result * loss_weight + if _saved_reduction == "sum": + result = result.sum() + elif _saved_reduction == "mean": + result = result.sum() / loss_weight.sum() + reduction = _saved_reduction +elif _manual_soft_label_ce: + if reduction == "sum": + result = result.sum() + elif reduction == "mean": + result = result.mean() if reduction == "none": if soft_label: result = result.unsqueeze(-1) else: result = result.reshape(shp) -elif reduction == "sum": - result = result.sum() -else: - result = result.sum()/sum_weight """ code = Code( preprocess=defaults_code + pre.splitlines() + map_code, @@ -3210,7 +3252,11 @@ def apply(self, paddle_api: str) -> ConvertResult: result = result.permute(0, 2, 1) """ code = Code( - preprocess=default_code + pre.splitlines() + map_code, + preprocess=default_code + pre.splitlines() + map_code + [ + '# PyTorch rejects align_corners for nearest/area modes', + 'if _kwargs.get("mode", "nearest") in ("nearest", "area"):', + ' _kwargs.pop("align_corners", None)', + ], core=[core], postprocess=post.splitlines(), )