F.cross_entropy backward fails with broadcast shape mismatch #323

Description

@lvyufeng

F.cross_entropy(logits, targets).backward() fails when logits have shape (N, C) and targets have shape (N,). The backward pass hits an add op with mismatched shapes (N, C) vs (N,).

Repro

import candle as torch
import candle.nn.functional as F

logits = torch.randn(64, 10, requires_grad=True)
targets = torch.randint(0, 10, (64,))
loss = F.cross_entropy(logits, targets)
loss.backward()   # RuntimeError

Error

RuntimeError: operands could not be broadcast together with shapes (64,10) (64,) [op=add, device=cpu]

Stack trace (key frames)

  File "src/candle/_generated/functions.py", line 10996, in backward
    grad_self = redispatch("add", keyset, ...)
  ...
RuntimeError: operands could not be broadcast together with shapes (64,10) (64,)
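
The message matches the usual trailing-dimension broadcasting rule, which NumPy also enforces. The following is a minimal sketch using NumPy rather than candle, with a hypothetical helper `try_add` and placeholder arrays standing in for the two gradients. It shows why shapes (64, 10) and (64,) cannot be added while (64, 10) and (64, 1) can:

```python
import numpy as np

# Placeholder arrays (assumption): stand-ins for the two gradients in the trace.
grad_2d = np.ones((64, 10))   # an (N, C) gradient
grad_1d = np.ones(64)         # an (N,) per-sample gradient


def try_add(a, b):
    """Hypothetical helper: return the result shape, or the error text on failure."""
    try:
        return (a + b).shape
    except ValueError as exc:
        return str(exc)


# Broadcasting aligns trailing dimensions: 10 vs 64 do not match, so this fails.
failed = try_add(grad_2d, grad_1d)

# Adding a trailing axis gives (64, 1), which broadcasts against (64, 10).
fixed = try_add(grad_2d, grad_1d[:, None])

print(failed)
print(fixed)   # (64, 10)
```

If candle's CPU add follows the same rule, an unsqueeze on the last axis of the (N,) operand would be enough to make the shapes compatible.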

Root cause

The generated autograd backward for an add node in the cross_entropy computation graph receives gradient tensors with incompatible shapes, (N, C) and (N,). The likely cause is a missing unsqueeze or expand in the cross_entropy / nll_loss backward chain: a per-sample (N,) gradient needs to be reshaped to (N, 1) or expanded to (N, C) before it is added to an (N, C) tensor.
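
For reference, the expected shapes in this backward chain can be sketched in NumPy. This is not candle's generated code. It is an illustrative reimplementation of the standard nll_loss and log_softmax backward formulas with mean reduction, and the function name `cross_entropy_grad` is hypothetical. The `keepdims=True` reduction marks the step where an (N,) intermediate would otherwise meet an (N, C) tensor:

```python
import numpy as np


def cross_entropy_grad(logits, targets):
    """Illustrative gradient of mean-reduced cross entropy w.r.t. logits, shape (N, C)."""
    n, c = logits.shape

    # Forward: softmax probabilities, computed stably.
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)

    # nll_loss backward: scatter the per-sample value -1/N into an (N, C) tensor.
    nll_grad = np.zeros((n, c))
    nll_grad[np.arange(n), targets] = -1.0 / n

    # log_softmax backward: the row sum must keep its axis, giving (N, 1) not (N,).
    row_sum = nll_grad.sum(axis=1, keepdims=True)
    return nll_grad - probs * row_sum


rng = np.random.default_rng(0)
logits = rng.standard_normal((64, 10))
targets = rng.integers(0, 10, size=64)

grad = cross_entropy_grad(logits, targets)

# Closed form for comparison: (softmax(logits) - one_hot(targets)) / N.
shifted = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
one_hot = np.eye(10)[targets]
closed_form = (probs - one_hot) / 64

print(grad.shape)   # (64, 10)
```

Dropping `keepdims=True` in the row sum leaves an (N,) array that can no longer be combined with the (N, C) probabilities, which is the same kind of mismatch reported in the error above.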

Context

Discovered while running nn_tutorial.py ("What is torch.nn really?") from the PyTorch tutorials. The tutorial trains a logistic regression model with F.cross_entropy on MNIST (batch_size=64, num_classes=10).
