Skip to content

Commit 64aed72

Browse files
Evaluates parameters in GroupNorm's instancenorm
1 parent 0e51f4d commit 64aed72

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tripy/nvtripy/frontend/module/groupnorm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from nvtripy import export, utils
2121
from nvtripy.common import datatype
22+
from nvtripy.common.device import device
2223
from nvtripy.common.exception import raise_error
2324
from nvtripy.frontend.module.instancenorm import InstanceNorm
2425
from nvtripy.frontend.module.module import Module
@@ -90,6 +91,7 @@ def __init__(
9091
assert np_out.shape == torch_out.shape
9192
assert np.allclose(np_out, torch_out)
9293
"""
94+
from nvtripy.frontend.ops.copy import copy
9395
from nvtripy.frontend.ops.ones import ones
9496
from nvtripy.frontend.ops.zeros import zeros
9597

@@ -117,8 +119,12 @@ def __init__(self, instance_norm):
117119

118120
self.impl = Hide(InstanceNorm(self.num_groups, dtype=self.dtype, eps=self.eps))
119121
# Bypass shape checks:
120-
object.__setattr__(self.impl.instance_norm, "weight", ones((self.num_groups,), dtype=self.dtype))
121-
object.__setattr__(self.impl.instance_norm, "bias", zeros((self.num_groups,), dtype=self.dtype))
122+
object.__setattr__(
123+
self.impl.instance_norm, "weight", copy(ones((self.num_groups,), dtype=self.dtype), device=device("cpu"))
124+
)
125+
object.__setattr__(
126+
self.impl.instance_norm, "bias", copy(zeros((self.num_groups,), dtype=self.dtype), device=device("cpu"))
127+
)
122128

123129
def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor":
124130
r"""

0 commit comments

Comments
 (0)