We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 62d0c3e, commit 2996c9a (Copy full SHA for 2996c9a)
src/liger_kernel/ops/rms_norm.py
@@ -351,7 +351,7 @@ def _block_rms_norm_backward_kernel(
351
352
# calculate the gradient of W
353
if casting_mode == _CASTING_MODE_LLAMA:
354
- # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+ # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
355
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
356
else:
357
# here X_row is already in fp32 (see previous if block)
0 commit comments