@@ -312,7 +312,7 @@ template <> bool BenchmarkSuite<BS_MPI1>::declare_args(args_parser &parser, std:
312
312
set_description (
313
313
" The argument after -data_type is a one from possible strings,\n "
314
314
" Specifying that type will be used:\n "
315
- " byte, char, int, float, double\n "
315
+ " byte, char, int, float, double, float16, bfloat16 \n "
316
316
" \n "
317
317
" Example:\n "
318
318
" -data_type char\n "
@@ -323,7 +323,7 @@ template <> bool BenchmarkSuite<BS_MPI1>::declare_args(args_parser &parser, std:
323
323
set_description (
324
324
" The argument after -red_data_type is a one from possible strings,\n "
325
325
" Specifying that type will be used:\n "
326
- " char, int, float, double\n "
326
+ " char, int, float, double, float16, bfloat16 \n "
327
327
" \n "
328
328
" Example:\n "
329
329
" -red_data_type int\n "
@@ -414,12 +414,24 @@ MPI_Op get_op(MPI_Datatype type) {
414
414
MPI_Datatype mpi_int = MPI_INT;
415
415
MPI_Datatype mpi_float = MPI_FLOAT;
416
416
MPI_Datatype mpi_double = MPI_DOUBLE;
417
+ #ifdef MPIX_C_FLOAT16
418
+ MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
419
+ #endif
420
+ #ifdef MPIX_C_BF16
421
+ MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
422
+ #endif
417
423
size_t type_size = sizeof (MPI_Datatype);
418
424
419
425
if (!memcmp (&type, &mpi_char, type_size)) { MPI_Op_create (&(contig_sum<char >), 1 , &op); }
420
426
else if (!memcmp (&type, &mpi_int, type_size)) { MPI_Op_create (&(contig_sum<int >), 1 , &op); }
421
427
else if (!memcmp (&type, &mpi_float, type_size)) { MPI_Op_create (&(contig_sum<float >), 1 , &op); }
422
428
else if (!memcmp (&type, &mpi_double, type_size)) { MPI_Op_create (&(contig_sum<double >), 1 , &op); }
429
+ #ifdef MPIX_C_FLOAT16
430
+ else if (!memcmp (&type, &mpi_float16, type_size)) { op = MPI_OP_NULL; fprintf (stdout, " \n Warning: contig_type doesn't supported\n " ); }
431
+ #endif
432
+ #ifdef MPIX_C_BF16
433
+ else if (!memcmp (&type, &mpi_bfloat16, type_size)) { op = MPI_OP_NULL; fprintf (stdout, " \n Warning: contig_type doesn't supported \n " ); }
434
+ #endif
423
435
424
436
return op;
425
437
}
@@ -431,13 +443,25 @@ string type_to_name(MPI_Datatype type) {
431
443
MPI_Datatype mpi_int = MPI_INT;
432
444
MPI_Datatype mpi_float = MPI_FLOAT;
433
445
MPI_Datatype mpi_double = MPI_DOUBLE;
446
+ #ifdef MPIX_C_FLOAT16
447
+ MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
448
+ #endif
449
+ #ifdef MPIX_C_BF16
450
+ MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
451
+ #endif
434
452
size_t type_size = sizeof (MPI_Datatype);
435
453
436
454
if (!memcmp (&type, &mpi_byte, type_size)) { name = " MPI_BYTE" ; }
437
455
else if (!memcmp (&type, &mpi_char, type_size)) { name = " MPI_CHAR" ; }
438
456
else if (!memcmp (&type, &mpi_int, type_size)) { name = " MPI_INT" ; }
439
457
else if (!memcmp (&type, &mpi_float, type_size)) { name = " MPI_FLOAT" ; }
440
458
else if (!memcmp (&type, &mpi_double, type_size)) { name = " MPI_DOUBLE" ; }
459
+ #ifdef MPIX_C_FLOAT16
460
+ else if (!memcmp (&type, &mpi_float16, type_size)) { name = " MPIX_C_FLOAT16" ; }
461
+ #endif
462
+ #ifdef MPIX_C_BF16
463
+ else if (!memcmp (&type, &mpi_bfloat16, type_size)) { name = " MPIX_C_BF16" ; }
464
+ #endif
441
465
442
466
return name;
443
467
}
@@ -602,6 +626,16 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
602
626
} else if (given_data_type == " double" ) {
603
627
c_info.s_data_type = MPI_DOUBLE;
604
628
c_info.r_data_type = MPI_DOUBLE;
629
+ #ifdef MPIX_C_FLOAT16
630
+ } else if (given_data_type == " float16" ) {
631
+ c_info.s_data_type = MPIX_C_FLOAT16;
632
+ c_info.r_data_type = MPIX_C_FLOAT16;
633
+ #endif
634
+ #ifdef MPIX_C_BF16
635
+ } else if (given_data_type == " bfloat16" ) {
636
+ c_info.s_data_type = MPIX_C_BF16;
637
+ c_info.r_data_type = MPIX_C_BF16;
638
+ #endif
605
639
} else {
606
640
output << " Invalid data_type " << given_data_type << endl;
607
641
output << " Set data_type byte" << endl;
@@ -619,6 +653,14 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
619
653
c_info.red_data_type = MPI_FLOAT;
620
654
} else if (given_red_data_type == " double" ) {
621
655
c_info.red_data_type = MPI_DOUBLE;
656
+ #ifdef MPIX_C_FLOAT16
657
+ } else if (given_red_data_type == " float16" ) {
658
+ c_info.red_data_type = MPIX_C_FLOAT16;
659
+ #endif
660
+ #ifdef MPIX_C_BF16
661
+ } else if (given_red_data_type == " bfloat16" ) {
662
+ c_info.red_data_type = MPIX_C_BF16;
663
+ #endif
622
664
} else {
623
665
output << " Invalid red_data_type " << given_red_data_type << endl;
624
666
output << " Set red_data_type float" << endl;
0 commit comments