Skip to content

Commit a9884b2

Browse files
authored
[BACKEND] KUNLUNXIN xpu update to Commit 9855424(20251203) (#164)
1 parent a9da65c commit a9884b2

File tree

5 files changed

+603
-122
lines changed

5 files changed

+603
-122
lines changed

third_party/xpu/device/trigonometric.xpu

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ __device__ inline void printBool(bool b, int customIdx) {
6565
__device__ inline void printFloat(float a, int customIdx) {
6666
int clusterId = cluster_id();
6767
int coreId = core_id();
68-
if (clusterId == __CHECK_CLUSTER_ID && coreId == 0) {
68+
if (clusterId == __CHECK_CLUSTER_ID && coreId < 8) {
6969
printf("[printFloat_%d] cluster_id = %d, core_id = %d, local_x[0] = %f \n",
7070
customIdx, clusterId, coreId, a);
7171
}
@@ -76,7 +76,7 @@ __device__ inline void printFloat_spec(float a, int clu_id, int cid,
7676
int clusterId = cluster_id();
7777
int coreId = core_id();
7878
if (clusterId == __CHECK_CLUSTER_ID && coreId == cid) {
79-
printf("[printFloat_%d] cluster_id = %d, core_id = %d, local_x[0] = %f \n",
79+
printf("[printFloat_spec_%d] cluster_id = %d, core_id = %d, local_x[0] = %f \n",
8080
customIdx, clusterId, coreId, a);
8181
}
8282
}
@@ -93,21 +93,41 @@ __device__ inline void printFloat_all(float a, int customIdx) {
9393
__device__ inline void printInt(int a, int customIdx) {
9494
int clusterId = cluster_id();
9595
int coreId = core_id();
96-
if (coreId == 0) {
96+
if (clusterId == __CHECK_CLUSTER_ID && coreId < 8) {
9797
printf("[printInt_%d] cluster_id = %d, core_id = %d, local_x[0] = %d \n",
9898
customIdx, clusterId, coreId, a);
9999
}
100100
}
101101

102+
__device__ inline void printInt_spec(int a, int clu_id, int cid,
103+
int customIdx) {
104+
int clusterId = cluster_id();
105+
int coreId = core_id();
106+
if (clusterId == __CHECK_CLUSTER_ID && coreId == cid) {
107+
printf("[printInt_spec_%d] cluster_id = %d, core_id = %d, local_x[0] = %d \n",
108+
customIdx, clusterId, coreId, a);
109+
}
110+
}
111+
102112
__device__ inline void printInt64(int64_t a, int customIdx) {
103113
int clusterId = cluster_id();
104114
int coreId = core_id();
105-
if (clusterId == __CHECK_CLUSTER_ID && coreId == 0) {
115+
if (clusterId == __CHECK_CLUSTER_ID && coreId < 8) {
106116
printf("[printInt64_%d] cluster_id = %d, core_id = %d, local_x[0] = %ld \n",
107117
customIdx, clusterId, coreId, a);
108118
}
109119
}
110120

121+
__device__ inline void printInt64_spec(int64_t a, int clu_id, int cid,
122+
int customIdx) {
123+
int clusterId = cluster_id();
124+
int coreId = core_id();
125+
if (clusterId == __CHECK_CLUSTER_ID && coreId == cid) {
126+
printf("[printInt64_spec_%d] cluster_id = %d, core_id = %d, local_x[0] = %ld \n",
127+
customIdx, clusterId, coreId, a);
128+
}
129+
}
130+
111131
__device__ inline void printMMaOp(int a, int b, int c, int isAcc) {
112132
int clusterId = cluster_id();
113133
int coreId = core_id();
@@ -287,6 +307,27 @@ __device__ inline void vstore2_lm(bfloat16* ptr, float32x16_t vl, float32x16_t v
287307
vstore_lm_float16x32(reinterpret_cast<float16*>(ptr), reinterpret_cast<float16x32_t>(vl));
288308
}
289309

310+
311+
__device__ inline void vstore2_lm_unordered(bfloat16* ptr, float32x16_t veven, float32x16_t vodd) {
312+
int mask_even = sveq_uint32x16(0, svand_uint32x16(0x10000, reinterpret_cast<uint32x16_t>(veven)));
313+
int mask_odd = sveq_uint32x16(0, svand_uint32x16(0x10000, reinterpret_cast<uint32x16_t>(vodd)));
314+
315+
veven = reinterpret_cast<float32x16_t>(
316+
svadd_uint32x16_mh(0x8000, reinterpret_cast<uint32x16_t>(veven), reinterpret_cast<uint32x16_t>(veven), ~mask_even));
317+
vodd = reinterpret_cast<float32x16_t>(
318+
svadd_uint32x16_mh(0x8000, reinterpret_cast<uint32x16_t>(vodd), reinterpret_cast<uint32x16_t>(vodd), ~mask_odd));
319+
veven = reinterpret_cast<float32x16_t>(
320+
svadd_uint32x16_mh(0x7FFF, reinterpret_cast<uint32x16_t>(veven), reinterpret_cast<uint32x16_t>(veven), mask_even));
321+
vodd = reinterpret_cast<float32x16_t>(
322+
svadd_uint32x16_mh(0x7FFF, reinterpret_cast<uint32x16_t>(vodd), reinterpret_cast<uint32x16_t>(vodd), mask_odd));
323+
324+
constexpr int mask = 0xaaaaaaaa; // 0b10101010101010101010101010101010
325+
vstore_lm_int16x32_mh(ptr, reinterpret_cast<int16x32_t>(veven), mask);
326+
constexpr int pose = 16;
327+
__asm__ __volatile__("vsrl.p %0, %1, %2" : "=&v"(vodd) : "r"(pose), "v"(vodd));
328+
vstore_lm_int16x32_mh(ptr, reinterpret_cast<int16x32_t>(vodd), (~mask));
329+
}
330+
290331
static __device__ inline void taylor_sin(float *C1, float *C3, float *C5,
291332
float *C7, float *C9) {
292333
*C1 = 1;

0 commit comments

Comments
 (0)