Description
`F.cross_entropy(logits, targets).backward()` raises a `RuntimeError` when `logits` has shape `(N, C)` and `targets` has shape `(N,)`. During the backward pass, an `add` op receives operands of shapes `(N, C)` and `(N,)`, which cannot be broadcast together.
Repro
```python
import candle as torch
import candle.nn.functional as F

logits = torch.randn(64, 10, requires_grad=True)  # (N, C) = (64, 10)
targets = torch.randint(0, 10, (64,))             # (N,) class indices in [0, 10)
loss = F.cross_entropy(logits, targets)           # forward completes
loss.backward()                                   # RuntimeError raised here
```
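For comparison, the gradient a correct backward should leave in `logits.grad` has the same `(N, C)` shape as `logits`; for the default mean reduction it is `(softmax(logits) - one_hot(targets)) / N`. The NumPy sketch below computes that closed form and checks it against central finite differences. It is an independent reference, not candle code, and the helper names (`cross_entropy`, `cross_entropy_grad`, `finite_diff_grad`) are hypothetical.

```python
import numpy as np

def cross_entropy(logits, targets):
    # Mean-reduced cross entropy: log-softmax, then NLL of the target class.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def cross_entropy_grad(logits, targets):
    # Closed-form gradient w.r.t. logits: (softmax - one_hot) / N, shape (N, C).
    n = logits.shape[0]
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    probs[np.arange(n), targets] -= 1.0
    return probs / n

def finite_diff_grad(logits, targets, eps=1e-6):
    # Central differences, one logit at a time (slow, but fine for a check).
    grad = np.zeros_like(logits)
    for idx in np.ndindex(*logits.shape):
        plus, minus = logits.copy(), logits.copy()
        plus[idx] += eps
        minus[idx] -= eps
        grad[idx] = (cross_entropy(plus, targets) - cross_entropy(minus, targets)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
logits = rng.standard_normal((64, 10))
targets = rng.integers(0, 10, size=64)
grad = cross_entropy_grad(logits, targets)
print(grad.shape)  # (64, 10)
print(np.allclose(grad, finite_diff_grad(logits, targets), atol=1e-6))  # True
```

Once the bug is fixed, `logits.grad` from the repro could be compared against this reference on the same inputs, assuming candle tensors can be converted to NumPy arrays.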
Error
```
RuntimeError: operands could not be broadcast together with shapes (64,10) (64,) [op=add, device=cpu]
```
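The message has the same form as NumPy's broadcasting error, and the standard broadcasting rule explains it: shapes are aligned from the trailing dimension, so a `(64,)` operand is compared against the class axis (size 10) of a `(64, 10)` operand, not against the batch axis. The NumPy sketch below illustrates the rule; that candle's CPU backend follows identical semantics is an assumption suggested by the error text.

```python
import numpy as np

# Trailing dimensions are compared first: 10 (classes) vs 64 (batch) -> incompatible.
try:
    np.broadcast_shapes((64, 10), (64,))
    compatible = True
except ValueError:
    compatible = False
print(compatible)  # False

# Adding a trailing axis, i.e. unsqueeze(-1), lines the per-sample values up with the batch axis.
print(np.broadcast_shapes((64, 10), (64, 1)))  # (64, 10)

# Pitfall: when N == C the (N,) tensor broadcasts silently along the class axis,
# so the same bug would produce a wrong gradient instead of an error.
print(np.broadcast_shapes((10, 10), (10,)))  # (10, 10)
```

The last case is why a regression test for this bug should keep the batch size different from the number of classes, as the repro does (64 vs 10).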
Stack trace (key frames)
```
  File "src/candle/_generated/functions.py", line 10996, in backward
    grad_self = redispatch("add", keyset, ...)
...
RuntimeError: operands could not be broadcast together with shapes (64,10) (64,)
```
Root cause
The generated autograd backward for an `add` node in the `cross_entropy` computation graph (`src/candle/_generated/functions.py`, line 10996) receives tensors of incompatible shapes: an `(N, C)` tensor and an `(N,)` tensor. This suggests a missing `unsqueeze` or `expand` in the `cross_entropy` / `nll_loss` backward chain: a per-sample `(N,)` gradient should be unsqueezed to `(N, 1)` or expanded to `(N, C)` before it is added to the `(N, C)` tensor.
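One plausible place for such a missing unsqueeze is the log-softmax step of the chain: its backward reduces the incoming gradient over the class axis to one value per sample, and that `(N,)` result must regain a class axis before it is combined with the `(N, C)` terms. The NumPy sketch below writes the `nll_loss` and `log_softmax` backward formulas out by hand under that assumption; the function names and the `unsqueeze` flag are hypothetical, and candle's generated formula may place the offending op elsewhere (in this formulation the mismatch surfaces in a multiply rather than the `add` seen in the report).

```python
import numpy as np

def log_softmax(x):
    # Row-wise log-softmax over the class axis, stabilized by subtracting the row max.
    shifted = x - x.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def nll_loss_backward(targets, num_classes):
    # Mean-reduced NLL: d(loss)/d(log_probs) is -1/N at each target class, 0 elsewhere.
    n = targets.shape[0]
    grad = np.zeros((n, num_classes))
    grad[np.arange(n), targets] = -1.0 / n
    return grad  # (N, C)

def log_softmax_backward(grad_out, log_probs, unsqueeze=True):
    # grad_in = grad_out - softmax * sum_j grad_out[:, j]
    row_sum = grad_out.sum(axis=1)  # (N,): one value per sample
    if unsqueeze:
        row_sum = row_sum[:, None]  # (N, 1): broadcasts along the class axis
    return grad_out + (-np.exp(log_probs) * row_sum)

rng = np.random.default_rng(0)
logits = rng.standard_normal((64, 10))
targets = rng.integers(0, 10, size=64)
log_probs = log_softmax(logits)
grad_lp = nll_loss_backward(targets, 10)

grad_logits = log_softmax_backward(grad_lp, log_probs)
print(grad_logits.shape)  # (64, 10)

try:
    log_softmax_backward(grad_lp, log_probs, unsqueeze=False)
    failed = False
except ValueError:
    failed = True  # same (64,10) vs (64,) mismatch, raised here by the multiply
print(failed)  # True
```

Reducing with `keepdims=True` (PyTorch's equivalent is `sum(dim, keepdim=True)`) is an alternative way to keep the class axis and avoid the explicit unsqueeze.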
Context
Discovered while running the `nn_tutorial.py` tutorial ("What is torch.nn really?") from the PyTorch docs. The tutorial trains a logistic regression model with `F.cross_entropy` on MNIST (`batch_size=64`, `num_classes=10`).