diff --git a/unimatch/attention.py b/unimatch/attention.py
index a10f758..240d5c7 100755
--- a/unimatch/attention.py
+++ b/unimatch/attention.py
@@ -1,166 +1,180 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import Optional
 
 from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d
 
+class single_head_full_attention(nn.Module):
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        # q, k, v: [B, L, C]
+        assert q.dim() == k.dim() == v.dim() == 3
 
-def single_head_full_attention(q, k, v):
-    # q, k, v: [B, L, C]
-    assert q.dim() == k.dim() == v.dim() == 3
+        scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5)  # [B, L, L]
+        attn = torch.softmax(scores, dim=2)  # [B, L, L]
+        out = torch.matmul(attn, v)  # [B, L, C]
 
-    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5)  # [B, L, L]
-    attn = torch.softmax(scores, dim=2)  # [B, L, L]
-    out = torch.matmul(attn, v)  # [B, L, C]
-
-    return out
+        return out
 
+class single_head_full_attention_1d(nn.Module):
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                h: Optional[int] = None,
+                w: Optional[int] = None
+                ) -> torch.Tensor:
+        # q, k, v: [B, L, C]
+        assert h is not None and w is not None
+        assert q.size(1) == h * w
 
-def single_head_full_attention_1d(q, k, v,
-                                  h=None,
-                                  w=None,
-                                  ):
-    # q, k, v: [B, L, C]
+        b, _, c = q.size()
 
-    assert h is not None and w is not None
-    assert q.size(1) == h * w
+        q = q.view(b, h, w, c)  # [B, H, W, C]
+        k = k.view(b, h, w, c)
+        v = v.view(b, h, w, c)
 
-    b, _, c = q.size()
+        scale_factor = c ** 0.5
 
-    q = q.view(b, h, w, c)  # [B, H, W, C]
-    k = k.view(b, h, w, c)
-    v = v.view(b, h, w, c)
+        scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor  # [B, H, W, W]
 
-    scale_factor = c ** 0.5
+        attn = torch.softmax(scores, dim=-1)
 
-    scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor  # [B, H, W, W]
+        out = torch.matmul(attn, v).view(b, -1, c)  # [B, H*W, C]
 
-    attn = torch.softmax(scores, dim=-1)
+        return out
 
-    out = torch.matmul(attn, v).view(b, -1, c)  # [B, H*W, C]
+class single_head_split_window_attention(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.split_feature = split_feature()
+        self.merge_splits = merge_splits()
 
-    return out
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                num_splits: int = 1,
+                with_shift: bool = False,
+                h: Optional[int] = None,
+                w: Optional[int] = None,
+                attn_mask: Optional[torch.Tensor] = None,
+                ) -> torch.Tensor:
+        # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+        # q, k, v: [B, L, C]
+        assert q.dim() == k.dim() == v.dim() == 3
 
+        assert h is not None and w is not None
+        assert q.size(1) == h * w
 
-def single_head_split_window_attention(q, k, v,
-                                       num_splits=1,
-                                       with_shift=False,
-                                       h=None,
-                                       w=None,
-                                       attn_mask=None,
-                                       ):
-    # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
-    # q, k, v: [B, L, C]
-    assert q.dim() == k.dim() == v.dim() == 3
+        b, _, c = q.size()
 
-    assert h is not None and w is not None
-    assert q.size(1) == h * w
+        b_new = b * num_splits * num_splits
 
-    b, _, c = q.size()
+        window_size_h = int(h // num_splits)
+        window_size_w = int(w // num_splits)
 
-    b_new = b * num_splits * num_splits
+        q = q.view(b, h, w, c)  # [B, H, W, C]
+        k = k.view(b, h, w, c)
+        v = v.view(b, h, w, c)
 
-    window_size_h = h // num_splits
-    window_size_w = w // num_splits
+        scale_factor = c ** 0.5
 
-    q = q.view(b, h, w, c)  # [B, H, W, C]
-    k = k.view(b, h, w, c)
-    v = v.view(b, h, w, c)
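+        # no shift by default; the shift sizes are set below when with_shift is True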
+        shift_size_w = 0
+        shift_size_h = 0
 
-    scale_factor = c ** 0.5
+        if with_shift:
+            assert attn_mask is not None  # compute once
+            shift_size_h = window_size_h // 2
+            shift_size_w = window_size_w // 2
 
-    if with_shift:
-        assert attn_mask is not None  # compute once
-        shift_size_h = window_size_h // 2
-        shift_size_w = window_size_w // 2
+            q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+            k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+            v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
 
-        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
-        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
-        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+        q = self.split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
+        k = self.split_feature(k, num_splits=num_splits, channel_last=True)
+        v = self.split_feature(v, num_splits=num_splits, channel_last=True)
 
-    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
-    k = split_feature(k, num_splits=num_splits, channel_last=True)
-    v = split_feature(v, num_splits=num_splits, channel_last=True)
+        scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
+                            ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]
 
-    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
-                          ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]
+        if with_shift and attn_mask is not None:
+            scores += attn_mask.repeat(b, 1, 1)
 
-    if with_shift:
-        scores += attn_mask.repeat(b, 1, 1)
+        attn = torch.softmax(scores, dim=-1)
 
-    attn = torch.softmax(scores, dim=-1)
+        out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]
 
-    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]
+        out = self.merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
+                                num_splits=num_splits, channel_last=True)  # [B, H, W, C]
 
-    out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
-                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]
+        # shift back
+        if with_shift:
+            out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
 
-    # shift back
-    if with_shift:
-        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
+        out = out.view(b, -1, c)
 
-    out = out.view(b, -1, c)
+        return out
 
-    return out
+class single_head_split_window_attention_1d(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.split_feature_1d = split_feature_1d()
+        self.merge_splits_1d = merge_splits_1d()
 
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                num_splits: int = 1,
+                with_shift: bool = False,
+                h: Optional[int] = None,
+                w: Optional[int] = None,
+                attn_mask: Optional[torch.Tensor] = None,
+                ) -> torch.Tensor:
+        # q, k, v: [B, L, C]
 
-def single_head_split_window_attention_1d(q, k, v,
-                                          relative_position_bias=None,
-                                          num_splits=1,
-                                          with_shift=False,
-                                          h=None,
-                                          w=None,
-                                          attn_mask=None,
-                                          ):
-    # q, k, v: [B, L, C]
+        assert h is not None and w is not None
+        assert q.size(1) == h * w
 
-    assert h is not None and w is not None
-    assert q.size(1) == h * w
+        b, _, c = q.size()
 
-    b, _, c = q.size()
+        b_new = b * num_splits * h
 
-    b_new = b * num_splits * h
+        window_size_w = w // num_splits
 
-    window_size_w = w // num_splits
+        q = q.view(b * h, w, c)  # [B*H, W, C]
+        k = k.view(b * h, w, c)
+        v = v.view(b * h, w, c)
 
-    q = q.view(b * h, w, c)  # [B*H, W, C]
-    k = k.view(b * h, w, c)
-    v = v.view(b * h, w, c)
+        scale_factor = c ** 0.5
 
-    scale_factor = c ** 0.5
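+        # default shift (none); updated below when with_shift is True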
+        shift_size_w = 0
 
-    if with_shift:
-        assert attn_mask is not None  # compute once
-        shift_size_w = window_size_w // 2
+        if with_shift:
+            assert attn_mask is not None  # compute once
+            shift_size_w = window_size_w // 2
 
-        q = torch.roll(q, shifts=-shift_size_w, dims=1)
-        k = torch.roll(k, shifts=-shift_size_w, dims=1)
-        v = torch.roll(v, shifts=-shift_size_w, dims=1)
+            q = torch.roll(q, shifts=-shift_size_w, dims=1)
+            k = torch.roll(k, shifts=-shift_size_w, dims=1)
+            v = torch.roll(v, shifts=-shift_size_w, dims=1)
 
-    q = split_feature_1d(q, num_splits=num_splits)  # [B*H*K, W/K, C]
-    k = split_feature_1d(k, num_splits=num_splits)
-    v = split_feature_1d(v, num_splits=num_splits)
+        q = self.split_feature_1d(q, num_splits=num_splits)  # [B*H*K, W/K, C]
+        k = self.split_feature_1d(k, num_splits=num_splits)
+        v = self.split_feature_1d(v, num_splits=num_splits)
 
-    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
-                          ) / scale_factor  # [B*H*K, W/K, W/K]
+        scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
+                            ) / scale_factor  # [B*H*K, W/K, W/K]
 
-    if with_shift:
-        # attn_mask: [K, W/K, W/K]
-        scores += attn_mask.repeat(b * h, 1, 1)  # [B*H*K, W/K, W/K]
+        if with_shift and attn_mask is not None:
+            # attn_mask: [K, W/K, W/K]
+            scores += attn_mask.repeat(b * h, 1, 1)  # [B*H*K, W/K, W/K]
 
-    attn = torch.softmax(scores, dim=-1)
+        attn = torch.softmax(scores, dim=-1)
 
-    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*H*K, W/K, C]
+        out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*H*K, W/K, C]
 
-    out = merge_splits_1d(out, h, num_splits=num_splits)  # [B, H, W, C]
+        out = self.merge_splits_1d(out, h, num_splits=num_splits)  # [B, H, W, C]
 
-    # shift back
-    if with_shift:
-        out = torch.roll(out, shifts=shift_size_w, dims=2)
+        # shift back
+        if with_shift:
+            out = torch.roll(out, shifts=shift_size_w, dims=2)
 
-    out = out.view(b, -1, c)
+        out = out.view(b, -1, c)
 
-    return out
+        return out
 
 
 class SelfAttnPropagation(nn.Module):
@@ -169,9 +183,7 @@ class SelfAttnPropagation(nn.Module):
     query: feature0, key: feature0, value: flow
     """
 
-    def __init__(self, in_channels,
-                 **kwargs,
-                 ):
+    def __init__(self, in_channels: int):
         super(SelfAttnPropagation, self).__init__()
 
         self.q_proj = nn.Linear(in_channels, in_channels)
@@ -181,11 +193,10 @@ def __init__(self, in_channels,
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def forward(self, feature0, flow,
-                local_window_attn=False,
-                local_window_radius=1,
-                **kwargs,
-                ):
+    def forward(self, feature0: torch.Tensor, flow: torch.Tensor,
+                local_window_attn: bool = False,
+                local_window_radius: int = 1
+                ) -> torch.Tensor:
         # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
         if local_window_attn:
             return self.forward_local_window_attn(feature0, flow,
@@ -214,9 +225,9 @@ def forward(self, feature0, flow,
 
         return out
 
-    def forward_local_window_attn(self, feature0, flow,
-                                  local_window_radius=1,
-                                  ):
+    def forward_local_window_attn(self, feature0: torch.Tensor, flow: torch.Tensor,
+                                  local_window_radius: int = 1
+                                  ) -> torch.Tensor:
         assert flow.size(1) == 2 or flow.size(1) == 1  # flow or disparity or depth
         assert local_window_radius > 0
 
@@ -227,21 +238,21 @@ def forward_local_window_attn(self, feature0, flow,
         feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
                                        ).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]
 
-        kernel_size = 2 * local_window_radius + 1
+        kernel_size = int(2 * local_window_radius + 1)
 
         feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
 
-        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
-                                   padding=local_window_radius)  # [B, C*(2R+1)^2), H*W]
+        feature0_window = F.unfold(feature0_proj, kernel_size=int(kernel_size),
+                                   padding=int(local_window_radius))  # [B, C*(2R+1)^2), H*W]
 
-        feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
-            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]
+        feature0_window = feature0_window.view(b, c, int(kernel_size ** 2), h, w).permute(
+            0, 3, 4, 1, 2).reshape(b * h * w, c, int(kernel_size ** 2))  # [B*H*W, C, (2R+1)^2]
 
-        flow_window = F.unfold(flow, kernel_size=kernel_size,
-                               padding=local_window_radius)  # [B, 2*(2R+1)^2), H*W]
+        flow_window = F.unfold(flow, kernel_size=int(kernel_size),
+                               padding=int(local_window_radius))  # [B, 2*(2R+1)^2), H*W]
 
-        flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute(
-            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel)  # [B*H*W, (2R+1)^2, 2]
+        flow_window = flow_window.view(b, value_channel, int(kernel_size ** 2), h, w).permute(
+            0, 3, 4, 2, 1).reshape(b * h * w, int(kernel_size ** 2), value_channel)  # [B*H*W, (2R+1)^2, 2]
 
         scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]
 
diff --git a/unimatch/geometry.py b/unimatch/geometry.py
index 775a957..ddee53a 100755
--- a/unimatch/geometry.py
+++ b/unimatch/geometry.py
@@ -1,9 +1,9 @@
 import torch
 import torch.nn.functional as F
+from typing import Optional, Tuple
 
-
-def coords_grid(b, h, w, homogeneous=False, device=None):
-    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]
+def coords_grid(b: int, h: int, w: int, homogeneous: bool = False, device: Optional[torch.device] = None):
+    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')  # [H, W]
 
     stacks = [x, y]
 
@@ -21,24 +21,28 @@ def coords_grid(b, h, w, homogeneous=False, device=None):
     return grid
 
 
-def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
+def generate_window_grid(h_min: int, h_max: int, w_min: int, w_max: int, len_h: int, len_w: int, device: Optional[torch.device] = None):
     assert device is not None
 
     x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
                            torch.linspace(h_min, h_max, len_h, device=device)],
-                          )
+                          indexing='ij')
     grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]
 
     return grid
 
 
-def normalize_coords(coords, h, w):
+def normalize_coords(coords: torch.Tensor, h: int, w: int):
     # coords: [B, H, W, 2]
-    c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
+    c = torch.tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
     return (coords - c) / c  # [-1, 1]
 
 
-def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
+def bilinear_sample(img: torch.Tensor, sample_coords: torch.Tensor,
+                    mode: str = 'bilinear',
+                    padding_mode: str = 'zeros',
+                    return_mask: bool = False
+                    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     # img: [B, C, H, W]
     # sample_coords: [B, 2, H, W] in image scale
     if sample_coords.size(1) != 2:  # [B, H, W, 2]
@@ -59,10 +63,13 @@ def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', r
 
         return img, mask
 
-    return img
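+    # always return a (sampled, mask) tuple so the return type is consistent; mask is None here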
+    return img, None
 
 
-def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
+def flow_warp(feature: torch.Tensor, flow: torch.Tensor,
+              mask: bool = False,
+              padding_mode: str = 'zeros'
+              ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     b, c, h, w = feature.size()
     assert flow.size(1) == 2
 
@@ -72,9 +79,9 @@ def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
                            return_mask=mask)
 
 
-def forward_backward_consistency_check(fwd_flow, bwd_flow,
-                                       alpha=0.01,
-                                       beta=0.5
+def forward_backward_consistency_check(fwd_flow: torch.Tensor, bwd_flow: torch.Tensor,
+                                       alpha: float = 0.01,
+                                       beta: float = 0.5
                                        ):
     # fwd_flow, bwd_flow: [B, 2, H, W]
     # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
@@ -82,8 +89,8 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow,
     assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
     flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]
 
-    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
-    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]
+    warped_bwd_flow, _ = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
+    warped_fwd_flow, _ = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]
 
     diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
     diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
@@ -96,7 +103,7 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow,
     return fwd_occ, bwd_occ
 
 
-def back_project(depth, intrinsics):
+def back_project(depth: torch.Tensor, intrinsics: torch.Tensor):
     # Back project 2D pixel coords to 3D points
     # depth: [B, H, W]
     # intrinsics: [B, 3, 3]
@@ -110,7 +117,10 @@ def back_project(depth, intrinsics):
     return points
 
 
-def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
+def camera_transform(points_ref: torch.Tensor,
+                     extrinsics_ref: Optional[torch.Tensor] = None,
+                     extrinsics_tgt: Optional[torch.Tensor] = None,
+                     extrinsics_rel: Optional[torch.Tensor] = None):
     # Transform 3D points from reference camera to target camera
     # points_ref: [B, 3, H, W]
     # extrinsics_ref: [B, 4, 4]
@@ -119,6 +129,8 @@ def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extri
     b, _, h, w = points_ref.shape
 
     if extrinsics_rel is None:
+        assert extrinsics_tgt is not None
+        assert extrinsics_ref is not None
         extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref))  # [B, 4, 4]
 
     points_tgt = torch.bmm(extrinsics_rel[:, :3, :3],
@@ -129,7 +141,9 @@ def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extri
     return points_tgt
 
 
-def reproject(points_tgt, intrinsics, return_mask=False):
+def reproject(points_tgt: torch.Tensor, intrinsics: torch.Tensor,
+              return_mask: bool = False
+              ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     # reproject to target view
     # points_tgt: [B, 3, H, W]
     # intrinsics: [B, 3, 3]
@@ -151,11 +165,15 @@ def reproject(points_tgt, intrinsics, return_mask=False):
 
         return pixel_coords, mask
 
-    return pixel_coords
+    return pixel_coords, None
 
 
-def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
-                     return_mask=False):
+def reproject_coords(depth_ref: torch.Tensor, intrinsics: torch.Tensor,
+                     extrinsics_ref: Optional[torch.Tensor] = None,
+                     extrinsics_tgt: Optional[torch.Tensor] = None,
+                     extrinsics_rel: Optional[torch.Tensor] = None,
+                     return_mask: bool = False
+                     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     # Compute reprojection sample coords
     points_ref = back_project(depth_ref, intrinsics)  # [B, 3, H, W]
     points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel)
@@ -166,15 +184,18 @@ def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=
 
         return reproj_coords, mask
 
-    reproj_coords = reproject(points_tgt, intrinsics,
+    reproj_coords, _ = reproject(points_tgt, intrinsics,
                               return_mask=return_mask)  # [B, 2, H, W] in image scale
 
-    return reproj_coords
+    return reproj_coords, None
 
 
-def compute_flow_with_depth_pose(depth_ref, intrinsics,
-                                 extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
-                                 return_mask=False):
+def compute_flow_with_depth_pose(depth_ref: torch.Tensor, intrinsics: torch.Tensor,
+                                 extrinsics_ref: Optional[torch.Tensor] = None,
+                                 extrinsics_tgt: Optional[torch.Tensor] = None,
+                                 extrinsics_rel: Optional[torch.Tensor] = None,
+                                 return_mask: bool = False
+                                 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     b, h, w = depth_ref.shape
     coords_init = coords_grid(b, h, w, device=depth_ref.device)  # [B, 2, H, W]
 
@@ -186,10 +207,10 @@ def compute_flow_with_depth_pose(depth_ref, intrinsics,
 
         return rigid_flow, mask
 
-    reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
+    reproj_coords, _ = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
                                      extrinsics_rel=extrinsics_rel,
                                      return_mask=return_mask)  # [B, 2, H, W]
 
     rigid_flow = reproj_coords - coords_init
 
-    return rigid_flow
+    return rigid_flow, None
diff --git a/unimatch/matching.py b/unimatch/matching.py
index 6471025..52d676f 100755
--- a/unimatch/matching.py
+++ b/unimatch/matching.py
@@ -1,12 +1,13 @@
 import torch
 import torch.nn.functional as F
+from typing import Tuple
 
 from .geometry import coords_grid, generate_window_grid, normalize_coords
 
 
-def global_correlation_softmax(feature0, feature1,
-                               pred_bidir_flow=False,
-                               ):
+def global_correlation_softmax(feature0: torch.Tensor, feature1: torch.Tensor,
+                               pred_bidir_flow: bool = False,
+                               ) -> Tuple[torch.Tensor, torch.Tensor]:
     # global correlation
     b, c, h, w = feature0.shape
     feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
@@ -36,9 +37,9 @@ def global_correlation_softmax(feature0, feature1,
     return flow, prob
 
 
-def local_correlation_softmax(feature0, feature1, local_radius,
-                              padding_mode='zeros',
-                              ):
+def local_correlation_softmax(feature0: torch.Tensor, feature1: torch.Tensor, local_radius: int,
+                              padding_mode: str = 'zeros',
+                              ) -> Tuple[torch.Tensor, torch.Tensor]:
     b, c, h, w = feature0.size()
     coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
     coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]
@@ -83,12 +84,12 @@ def local_correlation_softmax(feature0, feature1, local_radius,
     return flow, match_prob
 
 
-def local_correlation_with_flow(feature0, feature1,
-                                flow,
-                                local_radius,
-                                padding_mode='zeros',
-                                dilation=1,
-                                ):
+def local_correlation_with_flow(feature0: torch.Tensor, feature1: torch.Tensor,
+                                flow: torch.Tensor,
+                                local_radius: int,
+                                padding_mode: str = 'zeros',
+                                dilation: int = 1,
+                                ) -> torch.Tensor:
     b, c, h, w = feature0.size()
     coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
     coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]
@@ -123,8 +124,7 @@ def local_correlation_with_flow(feature0, feature1,
     return corr
 
 
-def global_correlation_softmax_stereo(feature0, feature1,
-                                      ):
+def global_correlation_softmax_stereo(feature0: torch.Tensor, feature1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     # global correlation on horizontal direction
     b, c, h, w = feature0.shape
 
@@ -151,8 +151,7 @@ def global_correlation_softmax_stereo(feature0, feature1,
     return disparity.unsqueeze(1), prob  # feature resolution
 
 
-def local_correlation_softmax_stereo(feature0, feature1, local_radius,
-                                     ):
+def local_correlation_softmax_stereo(feature0: torch.Tensor, feature1: torch.Tensor, local_radius: int) -> Tuple[torch.Tensor, torch.Tensor]:
     b, c, h, w = feature0.size()
     coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
     coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous()  # [B, H*W, 2]
@@ -200,13 +199,13 @@ def local_correlation_softmax_stereo(feature0, feature1, local_radius,
     return flow_x, match_prob
 
 
-def correlation_softmax_depth(feature0, feature1,
-                              intrinsics,
-                              pose,
-                              depth_candidates,
-                              depth_from_argmax=False,
-                              pred_bidir_depth=False,
-                              ):
+def correlation_softmax_depth(feature0: torch.Tensor, feature1: torch.Tensor,
+                              intrinsics: torch.Tensor,
+                              pose: torch.Tensor,
+                              depth_candidates: torch.Tensor,
+                              depth_from_argmax: bool = False,
+                              pred_bidir_depth: bool = False,
+                              ) -> Tuple[torch.Tensor, torch.Tensor]:
     b, c, h, w = feature0.size()
     assert depth_candidates.dim() == 4  # [B, D, H, W]
     scale_factor = c ** 0.5
@@ -236,9 +235,9 @@ def correlation_softmax_depth(feature0, feature1,
     return depth, match_prob
 
 
-def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth,
-                                    clamp_min_depth=1e-3,
-                                    ):
+def warp_with_pose_depth_candidates(feature1: torch.Tensor, intrinsics: torch.Tensor, pose: torch.Tensor, depth: torch.Tensor,
+                                    clamp_min_depth: float = 1e-3,
+                                    ) -> torch.Tensor:
     """
     feature1: [B, C, H, W]
     intrinsics: [B, 3, 3]
diff --git a/unimatch/transformer.py b/unimatch/transformer.py
index a93660c..1d38332 100755
--- a/unimatch/transformer.py
+++ b/unimatch/transformer.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from typing import Optional, Tuple
 
 from .attention import (single_head_full_attention, single_head_split_window_attention,
                         single_head_full_attention_1d, single_head_split_window_attention_1d)
@@ -8,10 +9,10 @@
 
 class TransformerLayer(nn.Module):
     def __init__(self,
-                 d_model=128,
-                 nhead=1,
-                 no_ffn=False,
-                 ffn_dim_expansion=4,
+                 d_model: int = 128,
+                 nhead: int = 1,
+                 no_ffn: bool = False,
+                 ffn_dim_expansion: int = 4,
                  ):
         super(TransformerLayer, self).__init__()
 
@@ -28,7 +29,14 @@ def __init__(self,
 
         self.norm1 = nn.LayerNorm(d_model)
 
+        self.single_head_split_window_attention = single_head_split_window_attention()
+        self.single_head_split_window_attention_1d = single_head_split_window_attention_1d()
+        self.single_head_full_attention = single_head_full_attention()
+        self.single_head_full_attention_1d = single_head_full_attention_1d()
+
         # no ffn after self-attn, with ffn after cross-attn
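+        # placeholders so self.mlp and self.norm2 always exist; replaced by real layers below when the FFN is used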
+        self.mlp = nn.Sequential()
+        self.norm2 = nn.Sequential()
         if not self.no_ffn:
             in_channels = d_model * 2
             self.mlp = nn.Sequential(
@@ -39,15 +47,14 @@ def __init__(self,
 
             self.norm2 = nn.LayerNorm(d_model)
 
-    def forward(self, source, target,
-                height=None,
-                width=None,
-                shifted_window_attn_mask=None,
-                shifted_window_attn_mask_1d=None,
-                attn_type='swin',
-                with_shift=False,
-                attn_num_splits=None,
-                ):
+    def forward(self, source: torch.Tensor, target: torch.Tensor,
+                height: Optional[int] = None,
+                width: Optional[int] = None,
+                shifted_window_attn_mask: Optional[torch.Tensor] = None,
+                shifted_window_attn_mask_1d: Optional[torch.Tensor] = None,
+                attn_type: str = 'swin',
+                with_shift: bool = False,
+                attn_num_splits: int = 0) -> torch.Tensor:
         # source, target: [B, L, C]
         query, key, value = source, target, target
 
@@ -65,7 +72,7 @@ def forward(self, source, target,
                 # without bringing obvious performance gains and thus the implementation is removed
                 raise NotImplementedError
             else:
-                message = single_head_split_window_attention(query, key, value,
+                message = self.single_head_split_window_attention(query, key, value,
                                                              num_splits=attn_num_splits,
                                                              with_shift=with_shift,
                                                              h=height,
@@ -79,7 +86,7 @@ def forward(self, source, target,
             else:
                 if is_self_attn:
                     if attn_num_splits > 1:
-                        message = single_head_split_window_attention(query, key, value,
+                        message = self.single_head_split_window_attention(query, key, value,
                                                                      num_splits=attn_num_splits,
                                                                      with_shift=with_shift,
                                                                      h=height,
@@ -88,11 +95,11 @@ def forward(self, source, target,
                                                                      )
                     else:
                         # full 2d attn
-                        message = single_head_full_attention(query, key, value)  # [N, L, C]
+                        message = self.single_head_full_attention(query, key, value)  # [N, L, C]
 
                 else:
                     # cross attn 1d
-                    message = single_head_full_attention_1d(query, key, value,
+                    message = self.single_head_full_attention_1d(query, key, value,
                                                             h=height,
                                                             w=width,
                                                             )
@@ -104,7 +111,7 @@ def forward(self, source, target,
                 if is_self_attn:
                     if attn_num_splits > 1:
                         # self attn shift window
-                        message = single_head_split_window_attention(query, key, value,
+                        message = self.single_head_split_window_attention(query, key, value,
                                                                      num_splits=attn_num_splits,
                                                                      with_shift=with_shift,
                                                                      h=height,
@@ -113,12 +120,12 @@ def forward(self, source, target,
                                                                      )
                     else:
                         # full 2d attn
-                        message = single_head_full_attention(query, key, value)  # [N, L, C]
+                        message = self.single_head_full_attention(query, key, value)  # [N, L, C]
                 else:
                     if attn_num_splits > 1:
                         assert shifted_window_attn_mask_1d is not None
                         # cross attn 1d shift
-                        message = single_head_split_window_attention_1d(query, key, value,
+                        message = self.single_head_split_window_attention_1d(query, key, value,
                                                                         num_splits=attn_num_splits,
                                                                         with_shift=with_shift,
                                                                         h=height,
@@ -126,13 +133,13 @@ def forward(self, source, target,
                                                                         attn_mask=shifted_window_attn_mask_1d,
                                                                         )
                     else:
-                        message = single_head_full_attention_1d(query, key, value,
+                        message = self.single_head_full_attention_1d(query, key, value,
                                                                 h=height,
                                                                 w=width,
                                                                 )
 
         else:
-            message = single_head_full_attention(query, key, value)  # [B, L, C]
+            message = self.single_head_full_attention(query, key, value)  # [B, L, C]
 
         message = self.merge(message)  # [B, L, C]
         message = self.norm1(message)
@@ -148,9 +155,9 @@ class TransformerBlock(nn.Module):
     """self attention + cross attention + FFN"""
 
     def __init__(self,
-                 d_model=128,
-                 nhead=1,
-                 ffn_dim_expansion=4,
+                 d_model: int = 128,
+                 nhead: int = 1,
+                 ffn_dim_expansion: int = 4,
                  ):
         super(TransformerBlock, self).__init__()
 
@@ -162,18 +169,19 @@ def __init__(self,
 
         self.cross_attn_ffn = TransformerLayer(d_model=d_model,
                                                nhead=nhead,
+                                               no_ffn=False,
                                                ffn_dim_expansion=ffn_dim_expansion,
                                                )
 
-    def forward(self, source, target,
-                height=None,
-                width=None,
-                shifted_window_attn_mask=None,
-                shifted_window_attn_mask_1d=None,
-                attn_type='swin',
-                with_shift=False,
-                attn_num_splits=None,
-                ):
+    def forward(self, source: torch.Tensor, target: torch.Tensor,
+                height: Optional[int] = None,
+                width: Optional[int] = None,
+                shifted_window_attn_mask: Optional[torch.Tensor] = None,
+                shifted_window_attn_mask_1d: Optional[torch.Tensor] = None,
+                attn_type: str = 'swin',
+                with_shift: bool = False,
+                attn_num_splits: int = 0
+                ) -> torch.Tensor:
         # source, target: [B, L, C]
 
         # self attention
@@ -181,6 +189,7 @@ def forward(self, source, target,
                                 height=height,
                                 width=width,
                                 shifted_window_attn_mask=shifted_window_attn_mask,
+                                shifted_window_attn_mask_1d=None,  # the 1d mask is only used by the cross-attention path
                                 attn_type=attn_type,
                                 with_shift=with_shift,
                                 attn_num_splits=attn_num_splits,
@@ -202,15 +211,17 @@ def forward(self, source, target,
 
 class FeatureTransformer(nn.Module):
     def __init__(self,
-                 num_layers=6,
-                 d_model=128,
-                 nhead=1,
-                 ffn_dim_expansion=4,
+                 num_layers: int = 6,
+                 d_model: int = 128,
+                 nhead: int = 1,
+                 ffn_dim_expansion: int = 4,
                  ):
         super(FeatureTransformer, self).__init__()
 
         self.d_model = d_model
         self.nhead = nhead
+        self.generate_shift_window_attn_mask = generate_shift_window_attn_mask()
+        self.generate_shift_window_attn_mask_1d = generate_shift_window_attn_mask_1d()
 
         self.layers = nn.ModuleList([
             TransformerBlock(d_model=d_model,
@@ -223,11 +234,11 @@ def __init__(self,
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def forward(self, feature0, feature1,
-                attn_type='swin',
-                attn_num_splits=None,
-                **kwargs,
-                ):
+
+    def forward(self, feature0: torch.Tensor, feature1: torch.Tensor,
+                attn_type: str = 'swin',
+                attn_num_splits: int = 0
+                ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         b, c, h, w = feature0.shape
         assert self.d_model == c
@@ -242,7 +253,7 @@ def forward(self, feature0, feature1,
             window_size_w = w // attn_num_splits
 
             # compute attn mask once
-            shifted_window_attn_mask = generate_shift_window_attn_mask(
+            shifted_window_attn_mask = self.generate_shift_window_attn_mask(
                 input_resolution=(h, w),
                 window_size_h=window_size_h,
                 window_size_w=window_size_w,
@@ -258,7 +269,7 @@ def forward(self, feature0, feature1,
             window_size_w = w // attn_num_splits
 
             # compute attn mask once
-            shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
+            shifted_window_attn_mask_1d = self.generate_shift_window_attn_mask_1d(
                 input_w=w,
                 window_size_w=window_size_w,
                 shift_size_w=window_size_w // 2,
diff --git a/unimatch/trident_conv.py b/unimatch/trident_conv.py
index 445663c..baa5837 100755
--- a/unimatch/trident_conv.py
+++ b/unimatch/trident_conv.py
@@ -5,6 +5,7 @@
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.modules.utils import _pair
+from typing import List
 
 
 class MultiScaleTridentConv(nn.Module):
@@ -61,7 +62,7 @@ def __init__(
         if self.bias is not None:
             nn.init.constant_(self.bias, 0)
 
-    def forward(self, inputs):
+    def forward(self, inputs: List[torch.Tensor]):
         num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
         assert len(inputs) == num_branch
 
diff --git a/unimatch/unimatch.py b/unimatch/unimatch.py
index 96db16e..9889ace 100755
--- a/unimatch/unimatch.py
+++ b/unimatch/unimatch.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import Optional, List, Dict
 
 from .backbone import CNNEncoder
 from .transformer import FeatureTransformer
@@ -43,6 +44,8 @@ def __init__(self,
 
         # propagation with self-attn
         self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels)
+        self.feature_add_position = feature_add_position(feature_channels)
+        self.upsampler = nn.Sequential()
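+        # placeholder so self.upsampler is always defined; the convex upsampling head is built below when used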
 
         if not self.reg_refine or task == 'depth':
             # convex upsampling simiar to RAFT
@@ -78,13 +81,14 @@ def extract_feature(self, img0, img1):
 
         return feature0, feature1
 
-    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
-                      is_depth=False):
+    def upsample_flow(self, flow: torch.Tensor, feature: Optional[torch.Tensor], bilinear: bool = False, upsample_factor: float = 8,
+                      is_depth: bool = False) -> torch.Tensor:
         if bilinear:
-            multiplier = 1 if is_depth else upsample_factor
+            multiplier = 1.0 if is_depth else upsample_factor
             up_flow = F.interpolate(flow, scale_factor=upsample_factor,
                                     mode='bilinear', align_corners=True) * multiplier
         else:
+            assert feature is not None
             concat = torch.cat((flow, feature), dim=1)
             mask = self.upsampler(concat)
             up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor,
@@ -92,23 +96,21 @@ def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
 
         return up_flow
 
-    def forward(self, img0, img1,
-                attn_type=None,
-                attn_splits_list=None,
-                corr_radius_list=None,
-                prop_radius_list=None,
-                num_reg_refine=1,
-                pred_bidir_flow=False,
-                task='flow',
-                intrinsics=None,
-                pose=None,  # relative pose transform
-                min_depth=1. / 0.5,  # inverse depth range
-                max_depth=1. / 10,
-                num_depth_candidates=64,
-                depth_from_argmax=False,
-                pred_bidir_depth=False,
-                **kwargs,
-                ):
+    def forward(self, img0: torch.Tensor, img1: torch.Tensor,
+                attn_type: str,
+                attn_splits_list: List[int],
+                corr_radius_list: List[int],
+                prop_radius_list: List[int],
+                num_reg_refine: int = 1,
+                pred_bidir_flow: bool = False,
+                task: str = 'flow',
+                intrinsics: Optional[torch.Tensor] = None,
+                pose: Optional[torch.Tensor] = None,  # relative pose transform
+                min_depth: float = 1. / 0.5,  # inverse depth range
+                max_depth: float = 1. / 10,
+                num_depth_candidates: int = 64,
+                depth_from_argmax: bool = False,
+                pred_bidir_depth: bool = False):
 
         if pred_bidir_flow:
             assert task == 'flow'
@@ -116,8 +118,8 @@ def forward(self, img0, img1,
         if task == 'depth':
             assert self.num_scales == 1  # multi-scale depth model is not supported yet
 
-        results_dict = {}
-        flow_preds = []
+        results_dict: Dict[str, List[torch.Tensor]] = {}
+        flow_preds: List[torch.Tensor] = []
 
         if task == 'flow':
             # stereo and depth tasks have normalized img in dataloader
@@ -126,7 +128,7 @@ def forward(self, img0, img1,
         # list of features, resolution low to high
         feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features
 
-        flow = None
+        flow: Optional[torch.Tensor] = None
 
         if task != 'depth':
             assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
@@ -146,14 +148,19 @@ def forward(self, img0, img1,
 
             if task == 'depth':
                 # scale intrinsics
+                assert intrinsics is not None
                 intrinsics_curr = intrinsics.clone()
                 intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor
+            else:
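+                # dummy value so intrinsics_curr is bound on every path (only used when task == 'depth')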
+                intrinsics_curr = torch.zeros(1, 1)
 
             if scale_idx > 0:
                 assert task != 'depth'  # not supported for multi-scale depth model
-                flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
+                assert flow is not None
+                flow = F.interpolate(flow, scale_factor=2.0, mode='bilinear', align_corners=True) * 2
 
             if flow is not None:
+                # flow propagated from the previous scale; detach it before warping feature1
                 assert task != 'depth'
                 flow = flow.detach()
 
@@ -163,19 +170,21 @@ def forward(self, img0, img1,
                     zeros = torch.zeros_like(flow)  # [B, 1, H, W]
                     # NOTE: reverse disp, disparity is positive
                     displace = torch.cat((-flow, zeros), dim=1)  # [B, 2, H, W]
-                    feature1 = flow_warp(feature1, displace)  # [B, C, H, W]
+                    feature1, _ = flow_warp(feature1, displace)  # [B, C, H, W]
                 elif task == 'flow':
-                    feature1 = flow_warp(feature1, flow)  # [B, C, H, W]
+                    feature1, _ = flow_warp(feature1, flow)  # [B, C, H, W]
                 else:
                     raise NotImplementedError
 
             attn_splits = attn_splits_list[scale_idx]
             if task != 'depth':
                 corr_radius = corr_radius_list[scale_idx]
+            else:
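+                # not used for the depth task; defined so corr_radius is bound on every path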
+                corr_radius = 0
             prop_radius = prop_radius_list[scale_idx]
 
             # add position to features
-            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
+            feature0, feature1 = self.feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
 
             # Transformer
             feature0, feature1 = self.transformer(feature0, feature1,
@@ -216,7 +225,12 @@ def forward(self, img0, img1,
                         raise NotImplementedError
 
             # flow or residual flow
-            flow = flow + flow_pred if flow is not None else flow_pred
+            if flow is not None:
+                # residual prediction: accumulate on top of the current flow estimate
+                flow = flow + flow_pred
+            else:
+                # no previous estimate yet, take the prediction directly
+                flow = flow_pred
 
             if task == 'stereo':
                 flow = flow.clamp(min=0)  # positive disparity
@@ -269,6 +283,8 @@ def forward(self, img0, img1,
                                                      is_depth=task == 'depth')
                         flow_preds.append(flow_up)
 
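+                    # tolerate num_reg_refine being passed as a 1-element tuple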
+                    if isinstance(num_reg_refine, tuple):
+                        num_reg_refine = num_reg_refine[0]
                     assert num_reg_refine > 0
                     for refine_iter_idx in range(num_reg_refine):
                         flow = flow.detach()
@@ -292,7 +308,7 @@ def forward(self, img0, img1,
                                                                        dim=0), torch.cat((feature1_ori,
                                                                                           feature0_ori), dim=0)
 
-                            flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1),
+                            flow_from_depth, _ = compute_flow_with_depth_pose(1. / flow.squeeze(1),
                                                                            intrinsics_curr,
                                                                            extrinsics_rel=pose,
                                                                            )
diff --git a/unimatch/utils.py b/unimatch/utils.py
index 0c3dbea..3ef9cf5 100755
--- a/unimatch/utils.py
+++ b/unimatch/utils.py
@@ -1,26 +1,27 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from .position import PositionEmbeddingSine
+from typing import Optional, Tuple
 
-
-def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
+def generate_window_grid(h_min: int, h_max: int, w_min: int, w_max: int, len_h: int, len_w: int, device: Optional[torch.device] = None) -> torch.Tensor:
     assert device is not None
 
     x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
                            torch.linspace(h_min, h_max, len_h, device=device)],
-                          )
+                          indexing='ij')
     grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]
 
     return grid
 
 
-def normalize_coords(coords, h, w):
+def normalize_coords(coords: torch.Tensor, h: int, w: int) -> torch.Tensor:
     # coords: [B, H, W, 2]
     c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
     return (coords - c) / c  # [-1, 1]
 
 
-def normalize_img(img0, img1):
+def normalize_img(img0: torch.Tensor, img1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     # loaded images are in [0, 255]
     # normalize by ImageNet mean and std
     mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
@@ -30,109 +31,113 @@ def normalize_img(img0, img1):
 
     return img0, img1
 
-
-def split_feature(feature,
-                  num_splits=2,
-                  channel_last=False,
-                  ):
-    if channel_last:  # [B, H, W, C]
-        b, h, w, c = feature.size()
-        assert h % num_splits == 0 and w % num_splits == 0
-
-        b_new = b * num_splits * num_splits
-        h_new = h // num_splits
-        w_new = w // num_splits
-
-        feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
-                               ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c)  # [B*K*K, H/K, W/K, C]
-    else:  # [B, C, H, W]
-        b, c, h, w = feature.size()
-        assert h % num_splits == 0 and w % num_splits == 0
-
-        b_new = b * num_splits * num_splits
-        h_new = h // num_splits
-        w_new = w // num_splits
-
-        feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
-                               ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new)  # [B*K*K, C, H/K, W/K]
-
-    return feature
-
-
-def merge_splits(splits,
-                 num_splits=2,
-                 channel_last=False,
-                 ):
-    if channel_last:  # [B*K*K, H/K, W/K, C]
-        b, h, w, c = splits.size()
-        new_b = b // num_splits // num_splits
-
-        splits = splits.view(new_b, num_splits, num_splits, h, w, c)
-        merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
-            new_b, num_splits * h, num_splits * w, c)  # [B, H, W, C]
-    else:  # [B*K*K, C, H/K, W/K]
-        b, c, h, w = splits.size()
-        new_b = b // num_splits // num_splits
-
-        splits = splits.view(new_b, num_splits, num_splits, c, h, w)
-        merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
-            new_b, c, num_splits * h, num_splits * w)  # [B, C, H, W]
-
-    return merge
-
-
-def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
-                                    shift_size_h, shift_size_w, device=torch.device('cuda')):
-    # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
-    # calculate attention mask for SW-MSA
-    h, w = input_resolution
-    img_mask = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1
-    h_slices = (slice(0, -window_size_h),
-                slice(-window_size_h, -shift_size_h),
-                slice(-shift_size_h, None))
-    w_slices = (slice(0, -window_size_w),
-                slice(-window_size_w, -shift_size_w),
-                slice(-shift_size_w, None))
-    cnt = 0
-    for h in h_slices:
-        for w in w_slices:
-            img_mask[:, h, w, :] = cnt
-            cnt += 1
-
-    mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
-
-    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
-    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
-    return attn_mask
-
-
-def feature_add_position(feature0, feature1, attn_splits, feature_channels):
-    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
-
-    if attn_splits > 1:  # add position in splited window
-        feature0_splits = split_feature(feature0, num_splits=attn_splits)
-        feature1_splits = split_feature(feature1, num_splits=attn_splits)
-
-        position = pos_enc(feature0_splits)
-
-        feature0_splits = feature0_splits + position
-        feature1_splits = feature1_splits + position
-
-        feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
-        feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
-    else:
-        position = pos_enc(feature0)
-
-        feature0 = feature0 + position
-        feature1 = feature1 + position
-
-    return feature0, feature1
-
-
-def upsample_flow_with_mask(flow, up_mask, upsample_factor,
-                            is_depth=False):
+class split_feature(nn.Module):
+    def forward(self, feature: torch.Tensor, num_splits: int = 2, channel_last: bool = False) -> torch.Tensor:
+        if channel_last:  # [B, H, W, C]
+            b, h, w, c = feature.size()
+            assert h % num_splits == 0 and w % num_splits == 0
+
+            b_new = b * num_splits * num_splits
+            h_new = h // num_splits
+            w_new = w // num_splits
+
+            feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
+                                ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c)  # [B*K*K, H/K, W/K, C]
+        else:  # [B, C, H, W]
+            b, c, h, w = feature.size()
+            assert h % num_splits == 0 and w % num_splits == 0
+
+            b_new = b * num_splits * num_splits
+            h_new = h // num_splits
+            w_new = w // num_splits
+
+            feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
+                                ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new)  # [B*K*K, C, H/K, W/K]
+
+        return feature
+
+class merge_splits(nn.Module):
+    def forward(self, splits: torch.Tensor, num_splits: int = 2, channel_last: bool = False) -> torch.Tensor:
+        if channel_last:  # [B*K*K, H/K, W/K, C]
+            b, h, w, c = splits.size()
+            new_b = b // num_splits // num_splits
+
+            splits = splits.view(new_b, num_splits, num_splits, h, w, c)
+            merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
+                new_b, num_splits * h, num_splits * w, c)  # [B, H, W, C]
+        else:  # [B*K*K, C, H/K, W/K]
+            b, c, h, w = splits.size()
+            new_b = b // num_splits // num_splits
+
+            splits = splits.view(new_b, num_splits, num_splits, c, h, w)
+            merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
+                new_b, c, num_splits * h, num_splits * w)  # [B, C, H, W]
+
+        return merge
+
+class generate_shift_window_attn_mask(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.split_feature = split_feature()
+
+    def forward(self, input_resolution: Tuple[int, int], window_size_h: int, window_size_w: int, shift_size_h: int, shift_size_w: int, device: torch.device = torch.device('cuda')) -> torch.Tensor:
+        # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+        # calculate attention mask for SW-MSA
+        h, w = input_resolution
+
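+        # build the 3x3 grid of shifted-window regions (row bands x column bands), each tagged with a distinct id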
+        mask1 = torch.ones((h - window_size_h,            w - window_size_w           )).to(device) * 0
+        mask2 = torch.ones((h - window_size_h,            window_size_w - shift_size_w)).to(device) * 1
+        mask3 = torch.ones((h - window_size_h,            shift_size_w                )).to(device) * 2
+        mask4 = torch.ones((window_size_h - shift_size_h, w - window_size_w           )).to(device) * 3
+        mask5 = torch.ones((window_size_h - shift_size_h, window_size_w - shift_size_w)).to(device) * 4
+        mask6 = torch.ones((window_size_h - shift_size_h, shift_size_w                )).to(device) * 5
+        mask7 = torch.ones((shift_size_h,                 w - window_size_w           )).to(device) * 6
+        mask8 = torch.ones((shift_size_h,                 window_size_w - shift_size_w)).to(device) * 7
+        mask9 = torch.ones((shift_size_h,                 shift_size_w                )).to(device) * 8
+        # Concatenate the masks to create the full mask
+        upper_mask  = torch.cat([mask1, mask2, mask3], dim=1)
+        middle_mask = torch.cat([mask4, mask5, mask6], dim=1)
+        lower_mask  = torch.cat([mask7, mask8, mask9], dim=1)
+        img_mask = torch.cat([upper_mask, middle_mask, lower_mask], dim=0).unsqueeze(0).unsqueeze(-1) # Add extra dimensions for batch size and channels
+
+        mask_windows = self.split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
+
+        mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        return attn_mask
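+
+# Mask values follow the Swin convention: 0 where query and key fall in the same shifted
+# region, -100 otherwise (the mask is added to the attention logits before the softmax).
+# Illustrative shape, computed on CPU:
+#   generate_shift_window_attn_mask()((8, 8), 4, 4, 2, 2, device=torch.device('cpu')).shape
+#   -> torch.Size([4, 16, 16])  # [num_windows, window_h * window_w, window_h * window_w]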
+
+class feature_add_position(nn.Module):
+    def __init__(self, feature_channels: int):
+        super().__init__()
+        self.split_feature = split_feature()
+        self.merge_splits = merge_splits()
+        self.pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
+
+    def forward(self, feature0: torch.Tensor, feature1: torch.Tensor,
+                attn_splits: int,
+                feature_channels: int,
+                ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if attn_splits > 1:  # add position in split windows
+            feature0_splits = self.split_feature(feature0, num_splits=attn_splits)
+            feature1_splits = self.split_feature(feature1, num_splits=attn_splits)
+
+            position = self.pos_enc(feature0_splits)
+
+            feature0_splits = feature0_splits + position
+            feature1_splits = feature1_splits + position
+
+            feature0 = self.merge_splits(feature0_splits, num_splits=attn_splits)
+            feature1 = self.merge_splits(feature1_splits, num_splits=attn_splits)
+        else:
+            position = self.pos_enc(feature0)
+
+            feature0 = feature0 + position
+            feature1 = feature1 + position
+
+        return feature0, feature1
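+
+# Usage sketch for feature_add_position (assuming [B, C, H, W] inputs with C == feature_channels):
+#   add_pos = feature_add_position(feature_channels=128)
+#   feature0, feature1 = add_pos(feature0, feature1, attn_splits=2, feature_channels=128)
+# With attn_splits > 1 the sine encoding is computed on the split windows, so every window
+# shares the same (purely spatial) positional code before the features are merged back.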
+
+
+def upsample_flow_with_mask(flow: torch.Tensor, up_mask: torch.Tensor, upsample_factor: int,
+                            is_depth: bool = False) -> torch.Tensor:
     # convex upsampling following raft
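+    # (each upsampled value is a convex combination of its 3x3 coarse-resolution
+    #  neighbourhood, with the combination weights taken from up_mask)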
 
     mask = up_mask
@@ -151,38 +156,32 @@ def upsample_flow_with_mask(flow, up_mask, upsample_factor,
 
     return up_flow
 
+class split_feature_1d(nn.Module):
+    def forward(self, feature: torch.Tensor, num_splits: int = 2) -> torch.Tensor:
+        # feature: [B, W, C]
+        b, w, c = feature.size()
+        assert w % num_splits == 0
 
-def split_feature_1d(feature,
-                     num_splits=2,
-                     ):
-    # feature: [B, W, C]
-    b, w, c = feature.size()
-    assert w % num_splits == 0
-
-    b_new = b * num_splits
-    w_new = w // num_splits
-
-    feature = feature.view(b, num_splits, w // num_splits, c
-                           ).view(b_new, w_new, c)  # [B*K, W/K, C]
-
-    return feature
+        b_new = b * num_splits
+        w_new = w // num_splits
 
+        feature = feature.view(b, num_splits, w // num_splits, c
+                               ).view(b_new, w_new, c)  # [B*K, W/K, C]
 
-def merge_splits_1d(splits,
-                    h,
-                    num_splits=2,
-                    ):
-    b, w, c = splits.size()
-    new_b = b // num_splits // h
+        return feature
 
-    splits = splits.view(new_b, h, num_splits, w, c)
-    merge = splits.view(
-        new_b, h, num_splits * w, c)  # [B, H, W, C]
+class merge_splits_1d(nn.Module):
+    def forward(self, splits: torch.Tensor, h: int, num_splits: int = 2) -> torch.Tensor:
+        b, w, c = splits.size()
+        new_b = b // num_splits // h
 
-    return merge
+        splits = splits.view(new_b, h, num_splits, w, c)
+        merge = splits.view(
+            new_b, h, num_splits * w, c)  # [B, H, W, C]
 
+        return merge
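+
+# 1D counterparts: split_feature_1d maps [B, W, C] -> [B*K, W/K, C], and merge_splits_1d folds
+# the windows (plus the height packed into the batch) back to [B, H, W, C], e.g. (illustrative):
+#   x = torch.randn(2, 4, 8, 16)                                 # [B, H, W, C]
+#   s = split_feature_1d()(x.view(2 * 4, 8, 16), num_splits=2)   # [B*H*K, W/K, C] = [16, 4, 16]
+#   torch.equal(merge_splits_1d()(s, h=4, num_splits=2), x)      # True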
 
-def window_partition_1d(x, window_size_w):
+def window_partition_1d(x: torch.Tensor, window_size_w: int) -> torch.Tensor:
     """
     Args:
         x: (B, W, C)
@@ -195,22 +194,18 @@ def window_partition_1d(x, window_size_w):
     x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C)
     return x
 
-
-def generate_shift_window_attn_mask_1d(input_w, window_size_w,
-                                       shift_size_w, device=torch.device('cuda')):
-    # calculate attention mask for SW-MSA
-    img_mask = torch.zeros((1, input_w, 1)).to(device)  # 1 W 1
-    w_slices = (slice(0, -window_size_w),
-                slice(-window_size_w, -shift_size_w),
-                slice(-shift_size_w, None))
-    cnt = 0
-    for w in w_slices:
-        img_mask[:, w, :] = cnt
-        cnt += 1
-
-    mask_windows = window_partition_1d(img_mask, window_size_w)  # nW, window_size, 1
-    mask_windows = mask_windows.view(-1, window_size_w)
-    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # nW, window_size, window_size
-    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
-    return attn_mask
+class generate_shift_window_attn_mask_1d(nn.Module):
+    def forward(self, input_w: int, window_size_w: int, shift_size_w: int,
+                device: torch.device = torch.device('cuda'),
+                ) -> torch.Tensor:
+        # calculate attention mask for SW-MSA
+        mask1 = torch.ones((input_w - window_size_w     )).to(device) * 0
+        mask2 = torch.ones((window_size_w - shift_size_w)).to(device) * 1
+        mask3 = torch.ones((shift_size_w                )).to(device) * 2
+        # Concatenate the masks to create the full mask
+        img_mask = torch.cat([mask1, mask2, mask3], dim=0).unsqueeze(0).unsqueeze(-1)
+
+        mask_windows = window_partition_1d(img_mask, window_size_w)  # nW, window_size, 1
+        mask_windows = mask_windows.view(-1, window_size_w)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # nW, window_size, window_size
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        return attn_mask
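+
+# Illustrative shape, computed on CPU:
+#   generate_shift_window_attn_mask_1d()(input_w=8, window_size_w=4, shift_size_w=2,
+#                                        device=torch.device('cpu')).shape
+#   -> torch.Size([2, 4, 4])  # [num_windows, window_size_w, window_size_w]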
diff --git a/utils/utils.py b/utils/utils.py
index 73d780f..f8e87f8 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -61,9 +61,10 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False, padding_mode='zer
 def coords_grid(batch, ht, wd, normalize=False):
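+    # Recent PyTorch versions warn unless meshgrid is given an explicit `indexing` argument;
+    # indexing='ij' (matrix indexing) reproduces the previous default behaviour.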
     if normalize:  # [-1, 1]
         coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1,
-                                2 * torch.arange(wd) / (wd - 1) - 1)
+                                2 * torch.arange(wd) / (wd - 1) - 1,
+                                indexing='ij')
     else:
-        coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+        coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij')
     coords = torch.stack(coords[::-1], dim=0).float()
     return coords[None].repeat(batch, 1, 1, 1)  # [B, 2, H, W]