diff --git a/DeDoDe/encoder.py b/DeDoDe/encoder.py
index 91880e7..55f99f4 100644
--- a/DeDoDe/encoder.py
+++ b/DeDoDe/encoder.py
@@ -6,7 +6,10 @@
 class VGG19(nn.Module):
     def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
         super().__init__()
-        self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+        # Map the deprecated `pretrained` flag onto the new-style `weights`
+        # argument (True -> default ImageNet weights, False -> random init).
+        vgg_kwargs = {"weights": "DEFAULT" if pretrained else None}
+        self.layers = nn.ModuleList(tvm.vgg19_bn(**vgg_kwargs).features[:40])
         # Maxpool layers: 6, 13, 26, 39
         self.amp = amp
         self.amp_dtype = amp_dtype
@@ -25,12 +28,15 @@ def forward(self, x, **kwargs):
 class VGG(nn.Module):
     def __init__(self, size = "19", pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
         super().__init__()
+        # Map the deprecated `pretrained` flag onto the new-style `weights`
+        # argument (True -> default ImageNet weights, False -> random init).
+        vgg_kwargs = {"weights": "DEFAULT" if pretrained else None}
         if size == "11":
-            self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22])
+            self.layers = nn.ModuleList(tvm.vgg11_bn(**vgg_kwargs).features[:22])
         elif size == "13":
-            self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28])
+            self.layers = nn.ModuleList(tvm.vgg13_bn(**vgg_kwargs).features[:28])
         elif size == "19":
-            self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+            self.layers = nn.ModuleList(tvm.vgg19_bn(**vgg_kwargs).features[:40])
         # Maxpool layers: 6, 13, 26, 39
         self.amp = amp
         self.amp_dtype = amp_dtype
diff --git a/DeDoDe/transformer/layers/attention.py b/DeDoDe/transformer/layers/attention.py
index 1f9b0c9..aa4bf12 100644
--- a/DeDoDe/transformer/layers/attention.py
+++ b/DeDoDe/transformer/layers/attention.py
@@ -9,6 +9,7 @@
 # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
 
 import logging
+import warnings
 
 from torch import Tensor
 from torch import nn
@@ -22,7 +23,7 @@
 
     XFORMERS_AVAILABLE = True
 except ImportError:
-    logger.warning("xFormers not available")
+    warnings.warn("xFormers not available")
 
     XFORMERS_AVAILABLE = False
 
diff --git a/DeDoDe/transformer/layers/block.py b/DeDoDe/transformer/layers/block.py
index 25488f5..3e350fb 100644
--- a/DeDoDe/transformer/layers/block.py
+++ b/DeDoDe/transformer/layers/block.py
@@ -10,6 +10,7 @@
 
 import logging
 from typing import Callable, List, Any, Tuple, Dict
+import warnings
 
 import torch
 from torch import nn, Tensor
@@ -29,7 +30,7 @@
 
     XFORMERS_AVAILABLE = True
 except ImportError:
-    logger.warning("xFormers not available")
+    warnings.warn("xFormers not available")
 
     XFORMERS_AVAILABLE = False
 
diff --git a/DeDoDe/utils.py b/DeDoDe/utils.py
index 9475dc8..91c4376 100644
--- a/DeDoDe/utils.py
+++ b/DeDoDe/utils.py
@@ -69,12 +69,13 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
 
 def get_grid(B,H,W, device = get_best_device()):
     x1_n = torch.meshgrid(
-        *[
-            torch.linspace(
-                -1 + 1 / n, 1 - 1 / n, n, device=device
-            )
-            for n in (B, H, W)
-        ]
+        *[
+            torch.linspace(
+                -1 + 1 / n, 1 - 1 / n, n, device=device
+            )
+            for n in (B, H, W)
+        ],
+        indexing = "ij",
     )
     x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
     return x1_n
@@ -217,7 +218,8 @@ def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bili
                 -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
             )
             for n in (B, H, W)
-        ]
+        ],
+        indexing = "ij",
     )
     x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
     mask, x2 = warp_kpts(
@@ -670,6 +672,6 @@ def homog_transform(Homog, x):
     return y
 
 def get_homog_warp(Homog, H, W, device = get_best_device()):
-    grid = torch.meshgrid(torch.linspace(-1+1/H,1-1/H,H, device = device), torch.linspace(-1+1/W,1-1/W,W, device = device))
+    grid = torch.meshgrid(torch.linspace(-1+1/H,1-1/H,H, device = device), torch.linspace(-1+1/W,1-1/W,W, device = device), indexing = "ij")
     x_A = torch.stack((grid[1], grid[0]), dim = -1)[None]
     x_A_to_B = homog_transform(Homog, x_A)