Description
In #170, an explicit cast was required when specifying the accumulator: `acc = x[tile_m, :].to(torch.float32)`. Without this cast, the kernel failed with an assertion failure in the multi-output code path, even though it should be valid Helion logic.
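For reference, a minimal sketch of the pattern (the kernel below is illustrative, not the exact reproducer from #170):

```python
import torch
import helion
import helion.language as hl


@helion.kernel
def layer_norm_sketch(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    for tile_m in hl.tile(m):
        # The explicit cast currently required; dropping .to(torch.float32)
        # hits the assertion in the multi-output code path.
        acc = x[tile_m, :].to(torch.float32)
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
        out[tile_m, :] = ((acc - mean) * torch.rsqrt(var + 1e-5)).to(x.dtype)
    return out
```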
What happens is that with the `torch.float32` cast, we have 4 buffers, with the last buffer (buf4) being the variance buffer, which reads the mean (buf2). This works nicely, since the nodes are directly connected in the graph.
However, in the case without the cast, `var_mean` does an implicit conversion here: https://github.com/pytorch-labs/helion/blob/main/helion/_compiler/inductor_lowering_extra.py#L133. That means buf4 is now the mean (converted to fp16) and buf3 is the variance. But since buf4 only reads from buf1 (the mean in fp32), buf3 ends up disconnected from the graph.
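To make the two cases concrete, here is roughly what the implicit conversion amounts to (an illustration following the buffer names above, not the actual lowered output):

```python
import torch

x_f16 = torch.randn(128, 64, dtype=torch.float16)

# With the explicit cast: buf4 (variance) reads buf2 (mean) directly,
# so every buffer has a user and the graph stays connected.

# Without the cast, the lowering computes in fp32 and converts the mean
# back to the input dtype, roughly:
var_f32, mean_f32 = torch.var_mean(x_f16.to(torch.float32), dim=-1)  # ~buf3, ~buf1
mean_f16 = mean_f32.to(x_f16.dtype)  # ~buf4: reads only buf1, the fp32 mean
# var_f32 (~buf3) is never read by the conversion node, so the variance
# buffer has no path to the final node and is disconnected from the graph.
```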

From @jansel, a proposed solution: in the multi-output case the final node should return a tuple of the two outputs. Maybe we need to introduce an extra node to combine the two outputs at the graph level, like `hl.multi_output(mean, variance)`.
Then we could have a pass that deletes those nodes and has the users read directly from mean and variance.
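A rough sketch of what that pass could look like (`hl.multi_output` does not exist yet; this also assumes the graph is a `torch.fx` graph whose consumers read the combined node via `operator.getitem`):

```python
import operator

import torch.fx as fx


def multi_output(*outputs):
    # Purely structural: ties all outputs to a single final node so that
    # none of them gets disconnected during lowering.
    return outputs


def remove_multi_output_nodes(gm: fx.GraphModule) -> None:
    """Delete multi_output nodes and re-point users at the real outputs."""
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is multi_output:
            for user in list(node.users):
                # Users read output i as getitem(node, i); rewire them to
                # the i-th argument of the multi_output node.
                assert user.target is operator.getitem
                user.replace_all_uses_with(node.args[user.args[1]])
                gm.graph.erase_node(user)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
```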