|
19 | 19 |
|
20 | 20 | from nvtripy import export, utils |
21 | 21 | from nvtripy.common import datatype |
| 22 | +from nvtripy.common.device import device |
22 | 23 | from nvtripy.common.exception import raise_error |
23 | 24 | from nvtripy.frontend.module.instancenorm import InstanceNorm |
24 | 25 | from nvtripy.frontend.module.module import Module |
@@ -90,6 +91,7 @@ def __init__( |
90 | 91 | assert np_out.shape == torch_out.shape |
91 | 92 | assert np.allclose(np_out, torch_out) |
92 | 93 | """ |
| 94 | + from nvtripy.frontend.ops.copy import copy |
93 | 95 | from nvtripy.frontend.ops.ones import ones |
94 | 96 | from nvtripy.frontend.ops.zeros import zeros |
95 | 97 |
|
@@ -117,8 +119,12 @@ def __init__(self, instance_norm): |
117 | 119 |
|
118 | 120 | self.impl = Hide(InstanceNorm(self.num_groups, dtype=self.dtype, eps=self.eps)) |
119 | 121 | # Bypass shape checks: |
120 | | - object.__setattr__(self.impl.instance_norm, "weight", ones((self.num_groups,), dtype=self.dtype)) |
121 | | - object.__setattr__(self.impl.instance_norm, "bias", zeros((self.num_groups,), dtype=self.dtype)) |
| 122 | + object.__setattr__( |
| 123 | + self.impl.instance_norm, "weight", copy(ones((self.num_groups,), dtype=self.dtype), device=device("cpu")) |
| 124 | + ) |
| 125 | + object.__setattr__( |
| 126 | + self.impl.instance_norm, "bias", copy(zeros((self.num_groups,), dtype=self.dtype), device=device("cpu")) |
| 127 | + ) |
122 | 128 |
|
123 | 129 | def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": |
124 | 130 | r""" |
|
0 commit comments