Recently I attended the PyTorch Developer event, where one of the posters advertised the torchtyping package. It allows more precise type hints for tensors by naming the expected dimensions with strings, e.g.:
def batch_outer_product(x: TensorType["batch", "x_channels"],
                        y: TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)
The dimensions of tensors can also be checked at runtime.
This is potentially useful for both developers and power users: the type hints document the expected shapes, and shape mismatches in custom modules produce clearer diagnostics.
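To make the runtime-checking idea concrete, here is a minimal, library-free sketch of what a named-dimension check does in principle. This is not torchtyping's implementation (as far as I know, torchtyping hooks into typeguard via `patch_typeguard()` and the `@typechecked` decorator). The names `FakeTensor`, `check_dims` and `batch_outer_product_shapes` are hypothetical and exist only for illustration.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only the shape matters here."""
    shape: Tuple[int, ...]


def check_dims(bindings: Dict[str, int], value, names: Sequence[str], label: str) -> None:
    """Check value.shape against named dims, binding each name to a single size."""
    shape = tuple(value.shape)
    if len(shape) != len(names):
        raise TypeError(
            f"{label}: expected {len(names)} dims {tuple(names)}, got shape {shape}"
        )
    for name, size in zip(names, shape):
        # The first time a name is seen it is bound; later uses must agree.
        expected = bindings.setdefault(name, size)
        if expected != size:
            raise TypeError(
                f"{label}: dim '{name}' is {size}, but was {expected} earlier"
            )


def batch_outer_product_shapes(x, y) -> Tuple[int, ...]:
    """Validate x: [batch, x_channels] and y: [batch, y_channels].

    Returns the shape the outer product would have.
    """
    bindings: Dict[str, int] = {}
    check_dims(bindings, x, ["batch", "x_channels"], "x")
    check_dims(bindings, y, ["batch", "y_channels"], "y")
    return (bindings["batch"], bindings["x_channels"], bindings["y_channels"])
```

With this sketch, passing inputs whose `batch` sizes differ raises a `TypeError` that names the offending dimension, which is the kind of diagnostic that would help users of custom modules.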
However, it is not compatible with MyPy static type-checking, meaning that TensorType must be treated as Any:
from torchtyping import TensorType # type: ignore
Whilst not essential, replacing the Tensor type-hints in the codebase with TensorType could be useful.