note on torch 1.11 vs torch 2.1 compatibility #117

Open
@BStudent

(INFORMATIONAL)
Note for users of PyTorch 2.x: the example function below works with PyTorch 1.11 but returns nan loss values under PyTorch 2.1.

UPDATED:

  1. This issue only affects conversion of data to pandas DataFrame (for visualization) in the penalization_visualization() demo function.
  2. Other code through example_simple_model() demo function appears to work correctly.

Root cause: predict contains a 0, so predict.log() contains -inf, which propagates to nan smoothed values under PyTorch 2.1 when returned as crit(predict.log(), torch.LongTensor([1])).data.

Workaround: replace the 0 in predict with 1.0e-10 or a similar epsilon value for plotting (not a real solution due to masking).
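The mechanism can be sketched in isolation, without depending on which PyTorch version is installed. This is a minimal illustration, not the repository's code: the `target` values assume a label-smoothing criterion that puts zero mass on column 0 (the values 0.9 and 0.1 / 3 are assumptions for the demo), and `masked_kl_sum` is a hypothetical helper that skips zero-target entries explicitly instead of relying on the library to do so.

```python
import torch

x = 1.0
d = x + 3 * 1

# Original prediction with an exact 0 in column 0, and the epsilon workaround.
predict_zero = torch.tensor([[0.0, x / d, 1 / d, 1 / d, 1 / d]])
predict_eps = torch.tensor([[1.0e-10, x / d, 1 / d, 1 / d, 1 / d]])

log_predict = predict_zero.log()  # column 0 is log(0) = -inf

# ASSUMPTION: a smoothed target distribution with zero mass on column 0.
target = torch.tensor([[0.0, 0.9, 0.1 / 3, 0.1 / 3, 0.1 / 3]])

# Multiplying a zero target by -inf yields nan (IEEE 754: 0 * inf is nan).
naive_terms = target * log_predict
print(naive_terms[0, 0])  # nan


def masked_kl_sum(log_pred, tgt):
    """Hypothetical helper: summed KL terms over entries where tgt > 0 only."""
    mask = tgt > 0
    t = tgt[mask]
    return (t * (t.log() - log_pred[mask])).sum()


safe = masked_kl_sum(log_predict, target)
safe_eps = masked_kl_sum(predict_eps.log(), target)
print(safe, safe_eps)  # both finite and equal: column 0 never enters the sum
```

Because the zero-target column is excluded before any arithmetic, the result does not depend on the epsilon chosen for plotting, which is one way to check that the workaround is not changing the plotted values.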

  • Added context in the function docstring below to help locate the affected code.
import torch

# NOTE: return value broken WRT PyTorch 2.1, SEE CODE:
def loss(x, crit):
    """
    This function follows the text (by A-T Maintainers):
    > Label smoothing actually starts to penalize the model if it gets
    > very confident about a given choice.
    """
    d = x + 3 * 1
    # Original line; the leading 0 becomes -inf after .log():
    # predict = torch.FloatTensor([[0, x / d, 1 / d, 1 / d, 1 / d]])
    predict = torch.FloatTensor([[1.0e-10, x / d, 1 / d, 1 / d, 1 / d]])  # <-- workaround

    # With the original predict (leading 0):
    #   >>> crit(predict.log(), torch.LongTensor([1])).data
    #   Out: tensor(nan)     # if torch.__version__ == 2.1
    #   Out: tensor(0.9514)  # if torch.__version__ == 1.11
    return crit(predict.log(), torch.LongTensor([1])).data
