Skip to content

Commit 357e92f

Browse files
committed
IMB-MPI1: FP16/BF16 data_type support
1 parent be5500e commit 357e92f

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

src_cpp/MPI1/MPI1_suite.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ template <> bool BenchmarkSuite<BS_MPI1>::declare_args(args_parser &parser, std:
312312
set_description(
313313
"The argument after -data_type is a one from possible strings,\n"
314314
"Specifying that type will be used:\n"
315-
"byte, char, int, float, double\n"
315+
"byte, char, int, float, double, float16, bfloat16\n"
316316
"\n"
317317
"Example:\n"
318318
"-data_type char\n"
@@ -323,7 +323,7 @@ template <> bool BenchmarkSuite<BS_MPI1>::declare_args(args_parser &parser, std:
323323
set_description(
324324
"The argument after -red_data_type is a one from possible strings,\n"
325325
"Specifying that type will be used:\n"
326-
"char, int, float, double\n"
326+
"char, int, float, double, float16, bfloat16\n"
327327
"\n"
328328
"Example:\n"
329329
"-red_data_type int\n"
@@ -414,12 +414,24 @@ MPI_Op get_op(MPI_Datatype type) {
414414
MPI_Datatype mpi_int = MPI_INT;
415415
MPI_Datatype mpi_float = MPI_FLOAT;
416416
MPI_Datatype mpi_double = MPI_DOUBLE;
417+
#ifdef MPIX_C_FLOAT16
418+
MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
419+
#endif
420+
#ifdef MPIX_C_BF16
421+
MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
422+
#endif
417423
size_t type_size = sizeof(MPI_Datatype);
418424

419425
if (!memcmp(&type, &mpi_char, type_size)) { MPI_Op_create(&(contig_sum<char>), 1, &op); }
420426
else if (!memcmp(&type, &mpi_int, type_size)) { MPI_Op_create(&(contig_sum<int>), 1, &op); }
421427
else if (!memcmp(&type, &mpi_float, type_size)) { MPI_Op_create(&(contig_sum<float>), 1, &op); }
422428
else if (!memcmp(&type, &mpi_double, type_size)) { MPI_Op_create(&(contig_sum<double>), 1, &op); }
429+
#ifdef MPIX_C_FLOAT16
430+
else if (!memcmp(&type, &mpi_float16, type_size)) { op = MPI_OP_NULL; fprintf(stdout, "\nWarning: contig_type doesn't supported\n"); }
431+
#endif
432+
#ifdef MPIX_C_BF16
433+
else if (!memcmp(&type, &mpi_bfloat16, type_size)) { op = MPI_OP_NULL; fprintf(stdout, "\nWarning: contig_type doesn't supported \n"); }
434+
#endif
423435

424436
return op;
425437
}
@@ -431,13 +443,25 @@ string type_to_name(MPI_Datatype type) {
431443
MPI_Datatype mpi_int = MPI_INT;
432444
MPI_Datatype mpi_float = MPI_FLOAT;
433445
MPI_Datatype mpi_double = MPI_DOUBLE;
446+
#ifdef MPIX_C_FLOAT16
447+
MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
448+
#endif
449+
#ifdef MPIX_C_BF16
450+
MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
451+
#endif
434452
size_t type_size = sizeof(MPI_Datatype);
435453

436454
if (!memcmp(&type, &mpi_byte, type_size)) { name = "MPI_BYTE"; }
437455
else if (!memcmp(&type, &mpi_char, type_size)) { name = "MPI_CHAR"; }
438456
else if (!memcmp(&type, &mpi_int, type_size)) { name = "MPI_INT"; }
439457
else if (!memcmp(&type, &mpi_float, type_size)) { name = "MPI_FLOAT"; }
440458
else if (!memcmp(&type, &mpi_double, type_size)) { name = "MPI_DOUBLE"; }
459+
#ifdef MPIX_C_FLOAT16
460+
else if (!memcmp(&type, &mpi_float16, type_size)) { name = "MPIX_C_FLOAT16"; }
461+
#endif
462+
#ifdef MPIX_C_BF16
463+
else if (!memcmp(&type, &mpi_bfloat16, type_size)) { name = "MPIX_C_BF16"; }
464+
#endif
441465

442466
return name;
443467
}
@@ -602,6 +626,16 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
602626
} else if (given_data_type == "double") {
603627
c_info.s_data_type = MPI_DOUBLE;
604628
c_info.r_data_type = MPI_DOUBLE;
629+
#ifdef MPIX_C_FLOAT16
630+
} else if (given_data_type == "float16") {
631+
c_info.s_data_type = MPIX_C_FLOAT16;
632+
c_info.r_data_type = MPIX_C_FLOAT16;
633+
#endif
634+
#ifdef MPIX_C_BF16
635+
} else if (given_data_type == "bfloat16") {
636+
c_info.s_data_type = MPIX_C_BF16;
637+
c_info.r_data_type = MPIX_C_BF16;
638+
#endif
605639
} else {
606640
output << "Invalid data_type " << given_data_type << endl;
607641
output << " Set data_type byte" << endl;
@@ -619,6 +653,14 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
619653
c_info.red_data_type = MPI_FLOAT;
620654
} else if (given_red_data_type == "double") {
621655
c_info.red_data_type = MPI_DOUBLE;
656+
#ifdef MPIX_C_FLOAT16
657+
} else if (given_red_data_type == "float16") {
658+
c_info.red_data_type = MPIX_C_FLOAT16;
659+
#endif
660+
#ifdef MPIX_C_BF16
661+
} else if (given_red_data_type == "bfloat16") {
662+
c_info.red_data_type = MPIX_C_BF16;
663+
#endif
622664
} else {
623665
output << "Invalid red_data_type " << given_red_data_type << endl;
624666
output << " Set red_data_type float" << endl;

0 commit comments

Comments
 (0)