|
24 | 24 | from nvtripy.frontend.module.parameter import DefaultParameter |
25 | 25 | from nvtripy.frontend.tensor import Tensor |
26 | 26 |
|
| 27 | +from nvtripy.frontend.ops import utils as op_utils |
| 28 | +from nvtripy.utils import wrappers |
| 29 | +from nvtripy.trace.ops.layernorm import LayerNorm as LayerNormOp |
| 30 | + |
| 31 | + |
@wrappers.interface(
    dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"},
    dtype_variables={"T1": ["float32", "float16", "bfloat16"]},
)
def layernorm(
    input: "nvtripy.Tensor",
    weight: "nvtripy.Tensor",
    bias: "nvtripy.Tensor",
    eps: float,
) -> "nvtripy.Tensor":
    """
    Creates a ``LayerNormOp`` trace operation from the input, weight, and bias tensors.

    The normalized shape is taken from ``weight.shape``. When the input has a higher
    rank than the weight, ``weight`` and ``bias`` are both reshaped to the input's rank
    by prepending dimensions of size 1.

    Args:
        input: The tensor to normalize.
        weight: The weight tensor; its shape is forwarded to the op as ``normalized_shape``.
        bias: The bias tensor; it is reshaped to the same padded shape as ``weight``.
        eps: Forwarded unchanged to the op as ``eps``
            (presumably the numerical-stability term - confirm against ``LayerNormOp``).

    Returns:
        The tensor produced by ``op_utils.create_op`` for ``LayerNormOp``.
    """
    normalized_shape = weight.shape
    normalized_rank = len(normalized_shape)
    num_leading_dims = input.rank - normalized_rank

    # TensorRT normalization expects weight and bias to have the input's rank,
    # i.e. a shape of [1, ...] + normalized_shape.
    if num_leading_dims > 0:
        from nvtripy.frontend.ops.reshape import reshape

        padded_shape = (1,) * num_leading_dims + normalized_shape
        weight, bias = [reshape(param, padded_shape) for param in (weight, bias)]

    op_inputs = [input, weight, bias]
    return op_utils.create_op(LayerNormOp, op_inputs, normalized_shape=normalized_shape, eps=eps)
| 61 | + |
27 | 62 |
|
28 | 63 | @export.public_api(document_under="operations/modules") |
29 | 64 | @dataclass |
@@ -109,14 +144,4 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": |
109 | 144 | Returns: |
110 | 145 | A tensor of the same shape as the input. |
111 | 146 | """ |
112 | | - from nvtripy.frontend.ops.reduce.mean import mean |
113 | | - from nvtripy.frontend.ops.reduce.var import var |
114 | | - from nvtripy.frontend.ops.unary.rsqrt import rsqrt |
115 | | - |
116 | | - # The mean and the variance are computed over the last D dimensions |
117 | | - D = len(self.normalized_shape) |
118 | | - reduce_dims = tuple(-i for i in range(D, 0, -1)) |
119 | | - mean_val = mean(x, dim=reduce_dims, keepdim=True) |
120 | | - var_val = var(x, dim=reduce_dims, keepdim=True, correction=0) + self.eps |
121 | | - x = (x - mean_val) * rsqrt(var_val) |
122 | | - return self.weight * x + self.bias |
| 147 | + return layernorm(x, self.weight, self.bias, self.eps) |
0 commit comments