Explicit Cast Needed in LayerNorm example for MultiOutput Path #344

@PaulZhang12

Description

In #170, an explicit cast was required when specifying the accumulator: `acc = x[tile_m, :].to(torch.float32)`. Without this cast, the kernel failed with an assertion failure in the multi-output code path, even though the uncast version should be valid Helion logic.
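
For context, a condensed sketch of the pattern from #170 (names like `layer_norm_fwd` and `eps` are illustrative, not necessarily the exact example code):

```python
import torch
import helion
import helion.language as hl


@helion.kernel
def layer_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    for tile_m in hl.tile(m):
        # The explicit cast below is what this issue is about: dropping
        # `.to(torch.float32)` triggers the multi-output assertion failure.
        acc = x[tile_m, :].to(torch.float32)
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True)
        normalized = (acc - mean) * torch.rsqrt(var + eps)
        result = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)
        out[tile_m, :] = result.to(x.dtype)
    return out
```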

What happens is that with the `torch.float32` cast, we get 4 buffers, with the last buffer (`buf4`) being the variance buffer, which reads the mean (`buf2`). This works nicely, since the nodes are connected directly in the graph:
*(screenshot: buffer graph with the cast; `buf4` (variance) reads directly from `buf2` (mean))*

However, in the case without the cast, `var_mean` does an implicit conversion here: https://github.com/pytorch-labs/helion/blob/main/helion/_compiler/inductor_lowering_extra.py#L133. As a result, `buf4` is now the mean (converted to fp16) and `buf3` is the variance. Since `buf4` only reads from `buf1` (the mean in fp32), `buf3` gets disconnected from the graph:

*(screenshot: buffer graph without the cast; `buf3` (variance) has no path to the final buffer)*
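
My reading of the effective decomposition (an assumption based on the description above, not a verbatim copy of the linked lowering code):

```python
import torch


def var_mean_low_precision(x16: torch.Tensor):
    # Hedged guess at what the lowering produces for an fp16 input;
    # the real code in inductor_lowering_extra.py may differ in detail.
    x32 = x16.to(torch.float32)
    var, mean = torch.var_mean(x32, dim=-1)  # buf3 (variance), buf1 (fp32 mean)
    # The implicit cast of the mean back to fp16 becomes its own buffer (buf4),
    # which reads only from buf1 and thereby leaves buf3 (variance) orphaned.
    return var, mean.to(x16.dtype)
```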

From @jansel, a proposed solution: in the multi-output case, the final node should return a tuple of the two outputs. Maybe we need to introduce an extra node to combine the two outputs at the graph level, like `hl.multi_output(mean, variance)`? Then we could have a pass to delete those nodes and have the users read directly from `mean` and `variance`.
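
A hedged sketch of what such a pass could look like at the FX-graph level; `multi_output` here stands in for the hypothetical `hl.multi_output` combining node from the suggestion and does not exist in Helion today:

```python
import operator

import torch
import torch.fx as fx


def multi_output(*tensors):
    # Hypothetical combining node: returns its inputs as a tuple so both
    # outputs stay reachable from a single final node in the graph.
    return tensors


def remove_multi_output(gm: fx.GraphModule) -> fx.GraphModule:
    # Cleanup pass sketch: rewire `getitem(multi_output(...), i)` users to
    # read the i-th input directly, then delete the combining node.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is multi_output:
            for user in list(node.users):
                if user.op == "call_function" and user.target is operator.getitem:
                    user.replace_all_uses_with(node.args[user.args[1]])
                    gm.graph.erase_node(user)
            if not node.users:
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```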
