Skip to content

Commit 236905b

Browse files
authored
Update layernorm to directly use TRT API (#624)
1 parent 3407959 commit 236905b

File tree

3 files changed

+81
-12
lines changed

3 files changed

+81
-12
lines changed

tripy/nvtripy/frontend/module/layernorm.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,41 @@
2424
from nvtripy.frontend.module.parameter import DefaultParameter
2525
from nvtripy.frontend.tensor import Tensor
2626

27+
from nvtripy.frontend.ops import utils as op_utils
28+
from nvtripy.utils import wrappers
29+
from nvtripy.trace.ops.layernorm import LayerNorm as LayerNormOp
30+
31+
32+
@wrappers.interface(
33+
dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"},
34+
dtype_variables={"T1": ["float32", "float16", "bfloat16"]},
35+
)
36+
def layernorm(
37+
input: "nvtripy.Tensor",
38+
weight: "nvtripy.Tensor",
39+
bias: "nvtripy.Tensor",
40+
eps: float,
41+
) -> "nvtripy.Tensor":
42+
43+
normalized_shape = weight.shape
44+
D = len(normalized_shape)
45+
input_rank = input.rank
46+
47+
# Reshape weight and bias to match input rank for TensorRT normalization (expects [1, ...] + normalized_shape)
48+
if input_rank > D:
49+
from nvtripy.frontend.ops.reshape import reshape
50+
51+
broadcast_shape = (1,) * (input_rank - D) + normalized_shape
52+
weight = reshape(weight, broadcast_shape)
53+
bias = reshape(bias, broadcast_shape)
54+
55+
return op_utils.create_op(
56+
LayerNormOp,
57+
[input, weight, bias],
58+
normalized_shape=normalized_shape,
59+
eps=eps,
60+
)
61+
2762

2863
@export.public_api(document_under="operations/modules")
2964
@dataclass
@@ -109,14 +144,4 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor":
109144
Returns:
110145
A tensor of the same shape as the input.
111146
"""
112-
from nvtripy.frontend.ops.reduce.mean import mean
113-
from nvtripy.frontend.ops.reduce.var import var
114-
from nvtripy.frontend.ops.unary.rsqrt import rsqrt
115-
116-
# The mean and the variance are computed over the last D dimensions
117-
D = len(self.normalized_shape)
118-
reduce_dims = tuple(-i for i in range(D, 0, -1))
119-
mean_val = mean(x, dim=reduce_dims, keepdim=True)
120-
var_val = var(x, dim=reduce_dims, keepdim=True, correction=0) + self.eps
121-
x = (x - mean_val) * rsqrt(var_val)
122-
return self.weight * x + self.bias
147+
return layernorm(x, self.weight, self.bias, self.eps)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from dataclasses import dataclass
19+
from typing import Sequence, Tuple
20+
import nvtripy.trace.ops.utils as op_utils
21+
from nvtripy.trace.ops.base import TraceOp
22+
from mlir_tensorrt.compiler.dialects import tensorrt
23+
24+
from mlir_tensorrt.compiler import ir
25+
26+
27+
@dataclass(repr=False)
28+
class LayerNorm(TraceOp):
29+
normalized_shape: Sequence[int]
30+
eps: float = 1e-5
31+
32+
infer_rank = op_utils.InferRankPolicies.same_as_input()
33+
34+
def infer_dtypes(self):
35+
self.outputs[0].dtype = self.inputs[0].dtype
36+
37+
def to_mlir(self, inputs, outputs):
38+
rank = outputs[0].rank
39+
D = len(self.normalized_shape)
40+
axis = ir.DenseI64ArrayAttr.get(list(range(rank - D, rank)))
41+
42+
return [tensorrt.normalization(inputs[0], inputs[1], inputs[2], axis=axis, eps=self.eps, num_groups=1)]

tripy/tests/frontend/module/test_layernorm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,7 @@ def test_layernorm_improper_dimensions(self):
2727
tp_layernorm.bias = tp.ones((2, 2))
2828

2929
x = tp.ones((5, 5, 5))
30-
with helper.raises(tp.TripyException, match="broadcast dimensions must be conformable"):
30+
with helper.raises(
31+
tp.TripyException, match="The normalization scale is not broadcast-compatible with the input at dimension 1"
32+
):
3133
tp_layernorm(x).eval()

0 commit comments

Comments
 (0)