Skip to content

Commit fb9ad99

Browse files
committed
apply PR swz30#56
1 parent 7a36b56 commit fb9ad99

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

basicsr/models/archs/restormer_arch.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,20 @@ def forward(self, x):
114114
qkv = self.qkv_dwconv(self.qkv(x))
115115
q,k,v = qkv.chunk(3, dim=1)
116116

117-
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
118-
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
119-
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
117+
'''
118+
https://github.com/swz30/Restormer/pull/56
119+
120+
Make q, k, and v contiguous to improve the performance of the subsequent normalization.
121+
After the rearrange operations, normalization along the last dimension is applied to q and k.
122+
A non-contiguous memory layout makes normalization along the last dimension slow.
123+
'''
124+
# q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
125+
# k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
126+
# v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
127+
128+
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
129+
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
130+
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
120131

121132
q = torch.nn.functional.normalize(q, dim=-1)
122133
k = torch.nn.functional.normalize(k, dim=-1)

0 commit comments

Comments
 (0)