Open
Description
🚀 The feature, motivation and pitch
Follow the CUDA implementation of CatArrayBatchedCopy_alignedK_contig and any other applicable kernels.
For the example below, BMG takes 4.15 ms per torch.cat call, while a 4080 takes 0.7 ms.
import time

import torch

device = torch.accelerator.current_accelerator()

# One wide input plus 15 narrow ones, all float16 in channels_last layout.
input_list = [torch.randn((2048, 512, 7, 7), dtype=torch.float16).to(memory_format=torch.channels_last).to(device)]
for _ in range(15):
    input_list.append(torch.randn((2048, 32, 7, 7), dtype=torch.float16).to(memory_format=torch.channels_last).to(device))

# Warm-up run, excluded from the timing.
r = torch.cat(input_list, 1)
torch.accelerator.synchronize()

start = time.time()
for _ in range(20):
    r = torch.cat(input_list, 1)
torch.accelerator.synchronize()
end = time.time()

# Average time per torch.cat call, in milliseconds.
print((end - start) / 20 * 1000)
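To put the two timings in perspective, the effective memory bandwidth can be estimated from the tensor shapes alone. The sketch below is only a rough model: it assumes that torch.cat reads every input element once and writes it once, and the helper names `cat_traffic_bytes` and `effective_bandwidth_gbs` are made up for this illustration.

```python
# Rough estimate of the effective bandwidth implied by the reported timings.
# Assumption: each input element is read once and written once, nothing else.

def cat_traffic_bytes(batch, channels_per_input, height, width, bytes_per_elem):
    """Bytes read plus bytes written by one concatenation along the channel dim."""
    total_channels = sum(channels_per_input)
    output_bytes = batch * total_channels * height * width * bytes_per_elem
    return 2 * output_bytes  # read all inputs + write the output


def effective_bandwidth_gbs(traffic_bytes, elapsed_ms):
    """Effective bandwidth in GB/s, with 1 GB = 1e9 bytes."""
    return traffic_bytes / (elapsed_ms * 1e-3) / 1e9


channels = [512] + [32] * 15  # input shapes from the benchmark
traffic = cat_traffic_bytes(2048, channels, 7, 7, bytes_per_elem=2)  # float16

for name, ms in (("BMG", 4.15), ("4080", 0.7)):
    print(f"{name}: {effective_bandwidth_gbs(traffic, ms):.0f} GB/s")
```

Under this model the BMG number works out to roughly 96 GB/s and the 4080 number to roughly 569 GB/s. Comparing each figure with the peak memory bandwidth of the respective device gives a sense of how much headroom a better kernel could recover.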
Alternatives
No response
Additional context
The following kernels need to be implemented:
- CatArrayBatchedCopy
- CatArrayBatchedCopy_contig
- CatArrayBatchedCopy_alignedK_contig
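For orientation, here is a plain-Python reference of the index math that a concatenation copy has to perform. It is not the CUDA kernel and says nothing about how the work should be vectorized or batched; the function name `cat_dim1_reference` is hypothetical. It models each input as a row-major `[outer, inner]` array, which matches the benchmark because in channels_last layout the channel dimension is innermost in memory (so `outer = N*H*W` and `inner = C`).

```python
def cat_dim1_reference(inputs, outer):
    """Concatenate row-major [outer, inner_i] arrays along the inner dimension.

    inputs: list of (flat_data, inner_size) pairs.
    Returns the flat row-major output of shape [outer, sum(inner_i)].
    """
    total_inner = sum(inner for _, inner in inputs)
    out = [None] * (outer * total_inner)
    offset = 0  # start position of the current input along the cat dimension
    for data, inner in inputs:
        for linear in range(outer * inner):
            row, col = divmod(linear, inner)
            out[row * total_inner + offset + col] = data[linear]
        offset += inner
    return out


a = ([1, 2, 3, 4], 2)  # 2x2: [[1, 2], [3, 4]]
b = ([5, 6], 1)        # 2x1: [[5], [6]]
result = cat_dim1_reference([a, b], outer=2)
print(result)
```

Each input contributes a contiguous run of `inner` elements per output row, so the 512-channel input yields long contiguous copies while the 32-channel inputs yield many short ones.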