@@ -65,7 +65,7 @@ __device__ inline void printBool(bool b, int customIdx) {
6565__device__ inline void printFloat(float a, int customIdx) {
6666 int clusterId = cluster_id();
6767 int coreId = core_id();
68- if (clusterId == __CHECK_CLUSTER_ID && coreId == 0 ) {
68+ if (clusterId == __CHECK_CLUSTER_ID && coreId < 8 ) {
6969 printf("[printFloat_%d] cluster_id = %d, core_id = %d, local_x[0] = %f \n",
7070 customIdx, clusterId, coreId, a);
7171 }
@@ -76,7 +76,7 @@ __device__ inline void printFloat_spec(float a, int clu_id, int cid,
7676 int clusterId = cluster_id();
7777 int coreId = core_id();
7878 if (clusterId == __CHECK_CLUSTER_ID && coreId == cid) {
79- printf("[printFloat_ %d] cluster_id = %d, core_id = %d, local_x[0] = %f \n",
79+ printf("[printFloat_spec_ %d] cluster_id = %d, core_id = %d, local_x[0] = %f \n",
8080 customIdx, clusterId, coreId, a);
8181 }
8282}
@@ -93,21 +93,41 @@ __device__ inline void printFloat_all(float a, int customIdx) {
9393__device__ inline void printInt(int a, int customIdx) {
9494 int clusterId = cluster_id();
9595 int coreId = core_id();
96- if (coreId == 0 ) {
96+ if (clusterId == __CHECK_CLUSTER_ID && coreId < 8 ) {
9797 printf("[printInt_%d] cluster_id = %d, core_id = %d, local_x[0] = %d \n",
9898 customIdx, clusterId, coreId, a);
9999 }
100100}
101101
102+ __device__ inline void printInt_spec(int a, int clu_id, int cid,
103+ int customIdx) {
104+ int clusterId = cluster_id();
105+ int coreId = core_id();
106+ if (clusterId == __CHECK_CLUSTER_ID && coreId == cid) {
107+ printf("[printInt_spec_%d] cluster_id = %d, core_id = %d, local_x[0] = %d \n",
108+ customIdx, clusterId, coreId, a);
109+ }
110+ }
111+
102112__device__ inline void printInt64(int64_t a, int customIdx) {
103113 int clusterId = cluster_id();
104114 int coreId = core_id();
105- if (clusterId == __CHECK_CLUSTER_ID && coreId == 0 ) {
115+ if (clusterId == __CHECK_CLUSTER_ID && coreId < 8 ) {
106116 printf("[printInt64_%d] cluster_id = %d, core_id = %d, local_x[0] = %ld \n",
107117 customIdx, clusterId, coreId, a);
108118 }
109119}
110120
121+ __device__ inline void printInt64_spec(int64_t a, int clu_id, int cid,
122+ int customIdx) {
123+ int clusterId = cluster_id();
124+ int coreId = core_id();
125+ if (clusterId == __CHECK_CLUSTER_ID && coreId == cid) {
126+ printf("[printInt64_spec_%d] cluster_id = %d, core_id = %d, local_x[0] = %ld \n",
127+ customIdx, clusterId, coreId, a);
128+ }
129+ }
130+
111131__device__ inline void printMMaOp(int a, int b, int c, int isAcc) {
112132 int clusterId = cluster_id();
113133 int coreId = core_id();
@@ -287,6 +307,27 @@ __device__ inline void vstore2_lm(bfloat16* ptr, float32x16_t vl, float32x16_t v
287307 vstore_lm_float16x32(reinterpret_cast<float16*>(ptr), reinterpret_cast<float16x32_t>(vl));
288308}
289309
310+
311+ __device__ inline void vstore2_lm_unordered(bfloat16* ptr, float32x16_t veven, float32x16_t vodd) {
312+ int mask_even = sveq_uint32x16(0, svand_uint32x16(0x10000, reinterpret_cast<uint32x16_t>(veven)));
313+ int mask_odd = sveq_uint32x16(0, svand_uint32x16(0x10000, reinterpret_cast<uint32x16_t>(vodd)));
314+
315+ veven = reinterpret_cast<float32x16_t>(
316+ svadd_uint32x16_mh(0x8000, reinterpret_cast<uint32x16_t>(veven), reinterpret_cast<uint32x16_t>(veven), ~mask_even));
317+ vodd = reinterpret_cast<float32x16_t>(
318+ svadd_uint32x16_mh(0x8000, reinterpret_cast<uint32x16_t>(vodd), reinterpret_cast<uint32x16_t>(vodd), ~mask_odd));
319+ veven = reinterpret_cast<float32x16_t>(
320+ svadd_uint32x16_mh(0x7FFF, reinterpret_cast<uint32x16_t>(veven), reinterpret_cast<uint32x16_t>(veven), mask_even));
321+ vodd = reinterpret_cast<float32x16_t>(
322+ svadd_uint32x16_mh(0x7FFF, reinterpret_cast<uint32x16_t>(vodd), reinterpret_cast<uint32x16_t>(vodd), mask_odd));
323+
324+ constexpr int mask = 0xaaaaaaaa; // 0b10101010101010101010101010101010
325+ vstore_lm_int16x32_mh(ptr, reinterpret_cast<int16x32_t>(veven), mask);
326+ constexpr int pose = 16;
327+ __asm__ __volatile__("vsrl.p %0, %1, %2" : "=&v"(vodd) : "r"(pose), "v"(vodd));
328+ vstore_lm_int16x32_mh(ptr, reinterpret_cast<int16x32_t>(vodd), (~mask));
329+ }
330+
290331static __device__ inline void taylor_sin(float *C1, float *C3, float *C5,
291332 float *C7, float *C9) {
292333 *C1 = 1;
0 commit comments