Skip to content

Commit 3bb86ae

Browse files
authored
Update CopyKernel.cpp
1 parent ac60ff0 commit 3bb86ae

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

src/ATen/native/xpu/sycl/CopyKernel.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
#include <ATen/core/Tensor.h>
66
#include <ATen/native/TensorIterator.h>
77

8-
#include <ATen/native/xpu/sycl/CopyKernel.h>
98
#include <ATen/native/xpu/sycl/Loops.h>
109

10+
#include <ATen/native/xpu/sycl/CopyKernel.h>
11+
1112
namespace at::native::xpu {
1213

1314
template <typename scalar_t>
@@ -101,6 +102,21 @@ void float8_copy_kernel_xpu(TensorIteratorBase& iter) {
101102
gpu_kernel(iter, CopyScalarFunc<Float8_e5m2fnuz>());
102103
break;
103104
}
105+
} else if (dtype == kFloat8_e8m0fnu) {
106+
switch (other_dtype) {
107+
case kFloat:
108+
gpu_kernel_nocast(iter, CastScalarFunc<float, Float8_e8m0fnu>());
109+
break;
110+
case kHalf:
111+
gpu_kernel_nocast(iter, CastScalarFunc<Half, Float8_e8m0fnu>());
112+
break;
113+
case kBFloat16:
114+
gpu_kernel_nocast(iter, CastScalarFunc<BFloat16, Float8_e8m0fnu>());
115+
break;
116+
default:
117+
gpu_kernel(iter, CopyScalarFunc<Float8_e8m0fnu>());
118+
break;
119+
}
104120
} else {
105121
TORCH_CHECK(
106122
false,
@@ -109,6 +125,16 @@ void float8_copy_kernel_xpu(TensorIteratorBase& iter) {
109125
}
110126
}
111127

128+
void float4_copy_kernel_xpu(TensorIteratorBase& iter) {
129+
ScalarType src_dtype = iter.dtype(1);
130+
131+
if (src_dtype == kFloat4_e2m1fn_x2) {
132+
gpu_kernel_nocast(iter, CopyScalarFunc<Float4_e2m1fn_x2>());
133+
} else {
134+
TORCH_CHECK(false, "Copy from ", src_dtype, " to Float4_e2m1fn_x2 has not been supported.");
135+
}
136+
}
137+
112138
void copy_kernel(TensorIteratorBase& iter) {
113139
ScalarType dtype = iter.common_dtype();
114140
if (isQIntType(dtype)) {
@@ -117,6 +143,8 @@ void copy_kernel(TensorIteratorBase& iter) {
117143
});
118144
} else if (isFloat8Type(iter.dtype(0))) {
119145
float8_copy_kernel_xpu(iter);
146+
} else if (iter.dtype(0) == kFloat4_e2m1fn_x2) {
147+
float4_copy_kernel_xpu(iter);
120148
} else {
121149
AT_DISPATCH_V2(
122150
dtype,
@@ -127,11 +155,8 @@ void copy_kernel(TensorIteratorBase& iter) {
127155
kBool,
128156
kBFloat16,
129157
kComplexHalf,
130-
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
131-
kFloat8_e4m3fn,
132-
kFloat8_e5m2,
133-
kFloat8_e4m3fnuz,
134-
kFloat8_e5m2fnuz);
158+
AT_EXPAND(AT_FLOAT8_TYPES),
159+
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
135160
}
136161
}
137162

0 commit comments

Comments
 (0)