Description
I printed the output of each RSTB block and found that it is not scaled correctly, which leads to a severe numerical explosion. At float16 precision it even causes an overflow error; it only works correctly after switching to float32. I'm using the officially released weights 001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth, and everything works fine in the test metrics. Is this inherent to the model? I can't add extra numerical operations inside the loop, because that would disrupt the learned distribution.
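For anyone who wants to reproduce this, here is a minimal debugging sketch (not part of the repo) that registers forward hooks on the RSTB blocks in `model.layers` and prints the std of each block's output, so you can watch the activations approach the fp16 limit (~65504). The constructor arguments and the `'params'` checkpoint key follow `main_test_swinir.py` for the x2 classical-SR model and may need adjusting for other checkpoints:

```python
import torch
from models.network_swinir import SwinIR

# Build the x2 classical-SR model; arguments mirror main_test_swinir.py and
# may need adjusting for other checkpoints.
model = SwinIR(upscale=2, in_chans=3, img_size=64, window_size=8,
               img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
               num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2,
               upsampler='pixelshuffle', resi_connection='1conv')
state = torch.load('001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth')
model.load_state_dict(state.get('params', state), strict=True)
model = model.half().cuda().eval()

def log_std(name):
    def hook(module, inputs, output):
        # `output` is the tensor returned by this RSTB block
        print(f'{name}: std={output.float().std().item():.1f}, '
              f'has_nan={torch.isnan(output).any().item()}')
    return hook

for i, layer in enumerate(model.layers):
    layer.register_forward_hook(log_std(f'RSTB {i}'))

with torch.no_grad():
    lq = torch.rand(1, 3, 64, 64, dtype=torch.float16, device='cuda')
    _ = model(lq)  # the printed std grows sharply from block to block
```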
This is an interesting issue. I just found out that the reason behind the fp16 instability is that `x` (shown below) goes out of the fp16 range during inference. Try to debug this piece of code in `class SwinIR` (on `main`). If you look at `x.std()` during the iteration through `self.layers`, you will find that `std` is increasing exponentially. After 6 iterations it is beyond the range of fp16 and gets converted into NaNs. This is then clipped to 1.0 at the end, and hence your black screen. @JingyunLiang, should this be the case here? I see that even when using fp32, `x.std()` is increasing exponentially during the loop.

```python
def forward_features(self, x):
    x_size = (x.shape[2], x.shape[3])
    x = self.patch_embed(x)
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)

    for layer in self.layers:
        x = layer(x, x_size)  # Here! It explodes on fp16 after reaching `x.std() > 52000`

    x = self.norm(x)  # B L C
    x = self.patch_unembed(x, x_size)

    return x
```
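Until this is addressed in the model itself, one possible workaround (just a sketch, not an official fix) is to keep the whole SwinIR forward in fp32 even when the surrounding pipeline runs in fp16, since the observed activations (`x.std() > 52000`) sit right at the fp16 limit and no rescaling can be inserted inside the RSTB loop without changing the learned distribution:

```python
import torch

# Assumes `model` is a SwinIR instance kept in fp32 (i.e. never call .half()
# on it), while the rest of the pipeline may run under fp16 autocast.
@torch.no_grad()
def swinir_infer_fp32(model, lq):
    # Leave any surrounding fp16 autocast region for this call.
    with torch.autocast(device_type='cuda', enabled=False):
        sr = model(lq.float())  # run the whole forward in fp32
    return sr.to(lq.dtype)      # hand the result back in the caller's dtype
```

This adds no extra numerical operations inside the loop, at the cost of roughly doubling the memory and compute of the SwinIR forward pass.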