
New kernels for concat #1764

Open
@jianyizh

🚀 The feature, motivation and pitch

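Follow the CUDA backend and add CatArrayBatchedCopy_alignedK_contig, along with any other applicable concat kernels.
With the example below, one torch.cat call takes 4.15 ms on Intel BMG (Battlemage) but only 0.7 ms on an NVIDIA 4080.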

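```python
import time

import torch

# Target device: XPU on Intel GPUs (e.g. BMG), CUDA on NVIDIA GPUs.
device = torch.accelerator.current_accelerator()

# One large channels_last fp16 tensor followed by 15 smaller ones,
# all concatenated along the channel dimension (dim=1).
input_list = [torch.randn((2048, 512, 7, 7), dtype=torch.float16)
              .to(memory_format=torch.channels_last).to(device)]
for _ in range(15):
    input_list.append(torch.randn((2048, 32, 7, 7), dtype=torch.float16)
                      .to(memory_format=torch.channels_last).to(device))

# Warm-up call so one-time setup cost is excluded from the timing.
r = torch.cat(input_list, 1)
torch.accelerator.synchronize()

start = time.perf_counter()
for _ in range(20):
    r = torch.cat(input_list, 1)
torch.accelerator.synchronize()
end = time.perf_counter()

# Average time per torch.cat call, in milliseconds.
print((end - start) / 20 * 1000)
```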
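As a rough sanity check on these timings, the sketch below estimates the effective memory bandwidth they imply. It assumes the concat is purely memory-bound and that every element is read once and written once (no cache reuse); the helper names are illustrative and not part of PyTorch.

```python
# Reported timings from this issue, in milliseconds per torch.cat call.
REPORTED_MS = {"BMG": 4.15, "4080": 0.7}


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n


def cat_traffic_bytes(shapes, elem_bytes):
    """Minimum DRAM traffic for a concat, assuming every input element is
    read once and every output element is written once."""
    read = sum(numel(s) for s in shapes) * elem_bytes
    write = read  # the output holds exactly the same elements as the inputs
    return read + write


# Shapes from the benchmark: one 512-channel input plus fifteen 32-channel inputs.
shapes = [(2048, 512, 7, 7)] + [(2048, 32, 7, 7)] * 15
traffic = cat_traffic_bytes(shapes, elem_bytes=2)  # float16 is 2 bytes

for name, ms in REPORTED_MS.items():
    gbps = traffic / (ms * 1e-3) / 1e9
    print(f"{name}: {traffic / 1e6:.1f} MB moved, {gbps:.1f} GB/s effective")
```

Under that assumption each call moves about 398 MB, which works out to roughly 96 GB/s on BMG versus roughly 569 GB/s on the 4080, i.e. about a 6x gap in achieved bandwidth.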

Alternatives

No response

Additional context

The following kernels need to be implemented:

  1. CatArrayBatchedCopy
  2. CatArrayBatchedCopy_contig
  3. CatArrayBatchedCopy_alignedK_contig
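One way to see what these kernels have to do for the benchmark above: in channels_last layout, concatenating along dim 1 is a strided 2D copy, where each input writes a contiguous run of C_i elements into every row of an [N*H*W, C_total] output. The Python sketch below models that on the host with plain tensor views. cat_dim1_channels_last and vector_aligned are hypothetical helpers, not PyTorch or CUDA APIs, and the alignment check encodes an assumption about what an alignedK-style kernel needs (each chunk starting on, and spanning, whole 16-byte vectors); the real CUDA conditions may differ.

```python
import torch


def cat_dim1_channels_last(inputs):
    """Hypothetical reference model of torch.cat(inputs, dim=1) for
    channels_last NCHW tensors (not the PyTorch kernel).

    channels_last memory is laid out as [N, H, W, C], so each input is viewed
    as a row-major [N*H*W, C_i] matrix and the output as [N*H*W, C_total].
    Each input is then one strided 2D copy into columns [off, off + C_i).
    """
    n, _, h, w = inputs[0].shape
    c_total = sum(t.shape[1] for t in inputs)
    out = torch.empty((n, c_total, h, w), dtype=inputs[0].dtype,
                      device=inputs[0].device, memory_format=torch.channels_last)
    out_2d = out.permute(0, 2, 3, 1).view(n * h * w, c_total)  # view, no copy
    off = 0
    for t in inputs:
        c = t.shape[1]
        # .view() only succeeds if the input really is channels_last.
        t_2d = t.permute(0, 2, 3, 1).view(n * h * w, c)
        out_2d[:, off:off + c].copy_(t_2d)
        off += c
    return out


def vector_aligned(channel_counts, elem_bytes, vec_bytes=16):
    """Assumed alignedK condition: every chunk starts on and spans a whole
    number of vec_bytes-wide vectors. If every chunk width is a multiple of
    the vector length, the output row stride C_total is one as well."""
    k = vec_bytes // elem_bytes  # elements per vector, e.g. 8 for float16
    off = 0
    for c in channel_counts:
        if off % k or c % k:
            return False
        off += c
    return True
```

For the shapes in this issue (512 channels followed by fifteen 32-channel inputs, in float16), every chunk is a multiple of 8 elements, so vector_aligned([512] + [32] * 15, elem_bytes=2) returns True, which suggests this case could benefit from vectorized loads and stores.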
