@@ -114,9 +114,20 @@ def forward(self, x):
114114 qkv = self .qkv_dwconv (self .qkv (x ))
115115 q ,k ,v = qkv .chunk (3 , dim = 1 )
116116
117- q = rearrange (q , 'b (head c) h w -> b head c (h w)' , head = self .num_heads )
118- k = rearrange (k , 'b (head c) h w -> b head c (h w)' , head = self .num_heads )
119- v = rearrange (v , 'b (head c) h w -> b head c (h w)' , head = self .num_heads )
117+ '''
118+ https://github.com/swz30/Restormer/pull/56
119+
120+ Make q, k, and v contiguous to get better performance for normalize.
121+ After the rearrange operations, q and k are normalized along their last dim.
122+ Rearrange yields non-contiguous tensors, and normalizing a non-contiguous tensor along its last dim performs poorly.
123+ '''
124+ # q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
125+ # k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
126+ # v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
127+
128+ q = rearrange (q , 'b (head c) h w -> b head c (h w)' , head = self .num_heads ).contiguous (memory_format = torch .contiguous_format )
129+ k = rearrange (k , 'b (head c) h w -> b head c (h w)' , head = self .num_heads ).contiguous (memory_format = torch .contiguous_format )
130+ v = rearrange (v , 'b (head c) h w -> b head c (h w)' , head = self .num_heads ).contiguous (memory_format = torch .contiguous_format )
120131
121132 q = torch .nn .functional .normalize (q , dim = - 1 )
122133 k = torch .nn .functional .normalize (k , dim = - 1 )
0 commit comments