5 changes: 4 additions & 1 deletion tester/accuracy.py
@@ -72,7 +72,10 @@ def test(self):
             return
 
         try:
-            device = torch.device("cuda:0")
+            if paddle.device.get_device() == "cpu":
+                device = torch.device("cpu")
+            else:
+                device = torch.device("cuda:0")
             torch.set_default_device(device)
             if not self.gen_torch_input():
                 print("gen_torch_input failed", flush=True)
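For context: paddle.device.get_device() returns a device string such as "cpu" or "gpu:0", and the hunk above mirrors that choice on the torch side so the tester does not request CUDA on a CPU-only machine. A minimal standalone sketch of the same selection logic (assumes both frameworks are importable; torch.set_default_device requires torch >= 2.0):

    import paddle
    import torch

    # Match the torch default device to whatever device paddle is currently using.
    if paddle.device.get_device() == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:0")
    torch.set_default_device(device)
    print(device)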
6 changes: 3 additions & 3 deletions tester/paddle_to_torch/mapping.json
@@ -2739,8 +2739,8 @@
             "weight": "None",
             "reduction": "'mean'",
             "soft_label": "False",
-            "label_smoothing": "0.0",
-            "reduction_original": "None"
+            "use_softmax": "True",
+            "label_smoothing": "0.0"
         },
         "paddle_torch_args_map": {
             "input": "input",
@@ -4906,4 +4906,4 @@
     "paddle.Tensor.zero_": {
         "torch_api": "torch.Tensor.zero_"
     }
-}
+}
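The replaced defaults above track the keyword arguments of paddle.nn.functional.cross_entropy, which in recent Paddle releases exposes use_softmax=True and label_smoothing=0.0 (the dropped reduction_original was an internal bookkeeping key of the converter, not a Paddle argument). A sketch of the call these defaults are meant to reproduce; shapes and values are illustrative only:

    import paddle
    import paddle.nn.functional as F

    logits = paddle.randn([4, 10])
    labels = paddle.randint(0, 10, shape=[4])
    # Defaults mirrored by the mapping entry: weight=None, reduction='mean',
    # soft_label=False, use_softmax=True, label_smoothing=0.0.
    loss = F.cross_entropy(logits, labels, weight=None, reduction='mean',
                           soft_label=False, use_softmax=True, label_smoothing=0.0)
    print(loss)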
86 changes: 66 additions & 20 deletions tester/paddle_to_torch/rules.py
@@ -696,42 +696,84 @@ def apply(self, paddle_api: str) -> ConvertResult:
         defaults_code, map_code = self.apply_generic()
         pre = """
 shp = label.shape
-if len(input.shape) > 2:
-    perm = [0] + [len(input.shape)-1]+ [i for i in range(1,len(input.shape)-1)]
+axis = locals().get('axis', -1)
+use_softmax = locals().get('use_softmax', True)
+if axis < 0:
+    axis += input.dim()
+_manual_soft_label_ce = soft_label and use_softmax and axis == input.dim() - 1
+if not use_softmax:
+    input = torch.nn.functional.softmax(input, dim=axis)
+if len(input.shape) > 2 and not _manual_soft_label_ce:
+    perm = [0] + [len(input.shape)-1] + [i for i in range(1, len(input.shape)-1)]
     input = input.permute(*perm)
-axis = locals().get('axis',-1)
 label = label.squeeze(-1)
 if weight is not None:
     weight.requires_grad = False
 if label.dtype == torch.int32:
     label = label.long()
-if soft_label and weight is not None and shp == input.shape:
-    reduction_original = reduction
-    weight_original = weight
+_manual_weight = soft_label and weight is not None and torch.is_floating_point(label)
+if _manual_soft_label_ce:
+    _original_label = label
+    _did_onehot_from_int = False
+    if label_smoothing > 0.0 and not torch.is_floating_point(label):
+        _original_label = torch.nn.functional.one_hot(label.long(), num_classes=input.shape[-1]).float()
+        label = _original_label
+        _did_onehot_from_int = True
+    if torch.is_floating_point(label) and not _did_onehot_from_int:
+        label = label.to(dtype=input.dtype)
+        _original_label = _original_label.to(dtype=input.dtype)
+    if label_smoothing > 0.0:
+        label = label * (1.0 - label_smoothing) + label_smoothing / input.shape[-1]
+    if _did_onehot_from_int:
+        label = label.to(dtype=input.dtype)
+        _original_label = _original_label.to(dtype=input.dtype)
+    _weighted_label = label
+    _manual_weight = weight is not None
+if _manual_weight:
+    _saved_reduction = reduction
+    _saved_weight = weight
     reduction = "none"
     weight = None
 """
         core = f"""
-result = {self.torch_api}(**_kwargs)
+if not use_softmax and not soft_label and label_smoothing == 0.0:
+    result = torch.nn.functional.nll_loss(
+        input=torch.log(_kwargs["input"]),
+        target=_kwargs["target"],
+        weight=_kwargs.get("weight"),
+        ignore_index=_kwargs.get("ignore_index", -100),
+        reduction=_kwargs.get("reduction", "mean"),
+    )
+elif _manual_soft_label_ce:
+    _log_prob = torch.nn.functional.log_softmax(input, dim=axis)
+    result = -(label * _log_prob).sum(dim=axis)
+else:
+    if not use_softmax:
+        _kwargs["input"] = torch.log(_kwargs["input"])
+    result = {self.torch_api}(**_kwargs)
 """
         post = """
-if reduction_original is not None:
-    reduction = reduction_original
-    loss_weight = label@weight_original
-    sum_weight = loss_weight.sum()
-    result *= loss_weight
-else:
-    sum_weight = result.numel()
-
+if _manual_weight:
+    if _manual_soft_label_ce:
+        loss_weight = (_weighted_label.to(dtype=_saved_weight.dtype) * _saved_weight).sum(dim=axis)
+    else:
+        loss_weight = label.to(dtype=_saved_weight.dtype) @ _saved_weight
+    result = result * loss_weight
+    if _saved_reduction == "sum":
+        result = result.sum()
+    elif _saved_reduction == "mean":
+        result = result.sum() / loss_weight.sum()
+    reduction = _saved_reduction
+elif _manual_soft_label_ce:
+    if reduction == "sum":
+        result = result.sum()
+    elif reduction == "mean":
+        result = result.mean()
 if reduction == "none":
     if soft_label:
         result = result.unsqueeze(-1)
     else:
         result = result.reshape(shp)
-elif reduction == "sum":
-    result = result.sum()
-else:
-    result = result.sum()/sum_weight
 """
         code = Code(
             preprocess=defaults_code + pre.splitlines() + map_code,
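A quick way to check the `_manual_soft_label_ce` branch above: soft-label cross entropy is just -(p * log_softmax(x)).sum(axis), and torch.nn.functional.cross_entropy has accepted class-probability targets directly since roughly torch 1.10, so the two can be compared. A sketch, independent of the rule machinery (names are local to the example):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(4, 10)
    soft_labels = torch.softmax(torch.randn(4, 10), dim=-1)  # each row sums to 1

    # Manual soft-label cross entropy, mirroring the _manual_soft_label_ce branch.
    manual = -(soft_labels * F.log_softmax(logits, dim=-1)).sum(dim=-1)

    # Built-in equivalent with probability targets; reduction="none" keeps per-sample losses.
    reference = F.cross_entropy(logits, soft_labels, reduction="none")

    print(torch.allclose(manual, reference))  # expected: True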
@@ -3210,7 +3252,11 @@ def apply(self, paddle_api: str) -> ConvertResult:
     result = result.permute(0, 2, 1)
 """
         code = Code(
-            preprocess=default_code + pre.splitlines() + map_code,
+            preprocess=default_code + pre.splitlines() + map_code + [
+                '# PyTorch rejects align_corners for nearest/area modes',
+                'if _kwargs.get("mode", "nearest") in ("nearest", "area"):',
+                '    _kwargs.pop("align_corners", None)',
+            ],
             core=[core],
             postprocess=post.splitlines(),
         )
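The extra preprocess lines exist because torch.nn.functional.interpolate raises a ValueError when align_corners is supplied (even as False) together with the "nearest" or "area" modes, whereas Paddle-side calls may carry the argument regardless of mode. A minimal reproduction of the behaviour being guarded against:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)

    # Fine: align_corners left unset for nearest-neighbour upsampling.
    y = F.interpolate(x, scale_factor=2, mode="nearest")

    # Raises ValueError (align_corners is only allowed for the linear/bilinear/
    # bicubic/trilinear modes), hence the pop("align_corners") in the rule.
    try:
        F.interpolate(x, scale_factor=2, mode="nearest", align_corners=False)
    except ValueError as exc:
        print(exc)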