
Add TorchTyping? #71

Open
@GilesStrong

Recently I attended the PyTorch Developer event, where one of the posters advertised the torchtyping package. It can be used to type-hint tensors more precisely by naming their expected dimensions with strings, e.g.:

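```python
from torchtyping import TensorType

def batch_outer_product(x: TensorType["batch", "x_channels"],
                        y: TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    # (batch, x_channels, 1) * (batch, 1, y_channels) broadcasts to (batch, x_channels, y_channels)
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```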

The dimensions of tensors can also be checked at runtime.
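To illustrate what such a runtime check involves, here is a minimal, dependency-free sketch that works on plain shape tuples rather than tensors, so it does not need torch installed. The helper `check_named_dims` is hypothetical and is not torchtyping's implementation; according to the torchtyping README, the real check is enabled by calling `patch_typeguard()` and decorating the function with typeguard's `@typechecked`. The idea is the same: each dimension name is bound to the first size seen for it, and any later disagreement (or a wrong number of dimensions) is reported as an error.

```python
from typing import Dict, Sequence, Tuple


def check_named_dims(
    shapes: Sequence[Tuple[int, ...]],
    specs: Sequence[Tuple[str, ...]],
) -> Dict[str, int]:
    """Hypothetical helper (not part of torchtyping): check that every
    dimension name maps to one consistent size across all shapes."""
    if len(shapes) != len(specs):
        raise ValueError("need exactly one spec per shape")
    bound: Dict[str, int] = {}  # dimension name -> size seen first
    for shape, spec in zip(shapes, specs):
        # Rank check: e.g. ("batch", "x_channels") requires a 2-D shape
        if len(shape) != len(spec):
            raise TypeError(f"expected {len(spec)} dims {spec}, got shape {shape}")
        for name, size in zip(spec, shape):
            # setdefault binds the name on first sight, else returns the earlier size
            first = bound.setdefault(name, size)
            if first != size:
                raise TypeError(
                    f"Dimension {name!r} of inconsistent size: got both {first} and {size}"
                )
    return bound


# Inputs and output of batch_outer_product with batch=4, x_channels=3, y_channels=5
print(check_named_dims(
    [(4, 3), (4, 5), (4, 3, 5)],
    [("batch", "x_channels"), ("batch", "y_channels"),
     ("batch", "x_channels", "y_channels")],
))  # {'batch': 4, 'x_channels': 3, 'y_channels': 5}

# A batch-size mismatch between x and y is rejected
try:
    check_named_dims([(4, 3), (2, 5)], [("batch", "x_channels"), ("batch", "y_channels")])
except TypeError as err:
    print(err)  # Dimension 'batch' of inconsistent size: got both 4 and 2
```

Binding each name on first sight is what ties the leading dimension of `x`, `y` and the return value together through the shared name "batch", while distinct names such as "x_channels" and "y_channels" are free to differ.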

This could benefit both developers and power users: the type hints become more informative, and shape mismatches in custom modules produce clearer diagnostics.

However, torchtyping is not compatible with MyPy static type-checking, so the import has to be silenced with `# type: ignore`, which means MyPy treats TensorType as Any:

```python
from torchtyping import TensorType  # type: ignore
```

Whilst not essential, replacing the Tensor type-hints in the codebase with TensorType could be useful.

Metadata

Assignees

No one assigned

    Labels

    documentation (Improvements or additions to documentation), enhancement (New feature or request), good first issue (Good for newcomers), low priority (Should be fixed eventually, but isn't urgent)

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests