55#include < ATen/core/Tensor.h>
66#include < ATen/native/TensorIterator.h>
77
8- #include < ATen/native/xpu/sycl/CopyKernel.h>
98#include < ATen/native/xpu/sycl/Loops.h>
109
10+ #include < ATen/native/xpu/sycl/CopyKernel.h>
11+
1112namespace at ::native::xpu {
1213
1314template <typename scalar_t >
@@ -101,6 +102,21 @@ void float8_copy_kernel_xpu(TensorIteratorBase& iter) {
101102 gpu_kernel (iter, CopyScalarFunc<Float8_e5m2fnuz>());
102103 break ;
103104 }
105+ } else if (dtype == kFloat8_e8m0fnu ) {
106+ switch (other_dtype) {
107+ case kFloat :
108+ gpu_kernel_nocast (iter, CastScalarFunc<float , Float8_e8m0fnu>());
109+ break ;
110+ case kHalf :
111+ gpu_kernel_nocast (iter, CastScalarFunc<Half, Float8_e8m0fnu>());
112+ break ;
113+ case kBFloat16 :
114+ gpu_kernel_nocast (iter, CastScalarFunc<BFloat16, Float8_e8m0fnu>());
115+ break ;
116+ default :
117+ gpu_kernel (iter, CopyScalarFunc<Float8_e8m0fnu>());
118+ break ;
119+ }
104120 } else {
105121 TORCH_CHECK (
106122 false ,
@@ -109,6 +125,16 @@ void float8_copy_kernel_xpu(TensorIteratorBase& iter) {
109125 }
110126}
111127
// Copy kernel for the packed 4-bit float type Float4_e2m1fn_x2.
// In this file, copy_kernel() calls this function when
// iter.dtype(0) == kFloat4_e2m1fn_x2.
//
// Only a same-dtype copy is handled: the operand at index 1 must also be
// kFloat4_e2m1fn_x2. Any other dtype at index 1 is rejected with an error.
//
// iter: the tensor iterator describing the copy; operand 1's dtype is
//       inspected here (presumably the copy source, as the error text
//       "Copy from <dtype>" suggests -- confirm against TensorIterator's
//       operand ordering).
128+ void float4_copy_kernel_xpu (TensorIteratorBase& iter) {
// dtype of operand 1, which the error message below reports as the "from" type.
129+ ScalarType src_dtype = iter.dtype (1 );
130+
131+ if (src_dtype == kFloat4_e2m1fn_x2 ) {
// Same-dtype case: launch the element-wise functor CopyScalarFunc through
// gpu_kernel_nocast (the "nocast" launcher, as opposed to the plain
// gpu_kernel used elsewhere in this file).
132+ gpu_kernel_nocast (iter, CopyScalarFunc<Float4_e2m1fn_x2>());
133+ } else {
// The check condition is the literal `false`, so reaching this branch always
// reports the unsupported-conversion error, naming the offending dtype.
134+ TORCH_CHECK (false , " Copy from " , src_dtype, " to Float4_e2m1fn_x2 has not been supported." );
135+ }
136+ }
137+
112138void copy_kernel (TensorIteratorBase& iter) {
113139 ScalarType dtype = iter.common_dtype ();
114140 if (isQIntType (dtype)) {
@@ -117,6 +143,8 @@ void copy_kernel(TensorIteratorBase& iter) {
117143 });
118144 } else if (isFloat8Type (iter.dtype (0 ))) {
119145 float8_copy_kernel_xpu (iter);
146+ } else if (iter.dtype (0 ) == kFloat4_e2m1fn_x2 ) {
147+ float4_copy_kernel_xpu (iter);
120148 } else {
121149 AT_DISPATCH_V2 (
122150 dtype,
@@ -127,11 +155,8 @@ void copy_kernel(TensorIteratorBase& iter) {
127155 kBool ,
128156 kBFloat16 ,
129157 kComplexHalf ,
130- AT_EXPAND (AT_BAREBONES_UNSIGNED_TYPES),
131- kFloat8_e4m3fn ,
132- kFloat8_e5m2 ,
133- kFloat8_e4m3fnuz ,
134- kFloat8_e5m2fnuz );
158+ AT_EXPAND (AT_FLOAT8_TYPES),
159+ AT_EXPAND (AT_BAREBONES_UNSIGNED_TYPES));
135160 }
136161}
137162
0 commit comments