2121#include < benchmark/benchmark.h>
2222#include < kvikio/defaults.hpp>
2323
24- enum ScalingType : int64_t {
24+ namespace kvikio {
25+ enum class ScalingType : uint8_t {
2526 StrongScaling,
2627 WeakScaling,
2728};
@@ -35,21 +36,19 @@ void task_compute(std::size_t num_compute_iterations)
3536 }
3637}
3738
39+ template <ScalingType scaling_type>
3840void BM_threadpool_compute (benchmark::State& state)
3941{
40- auto num_threads = state.range (0 );
41- auto compute_bench_type = state.range (1 );
42+ auto num_threads = state.range (0 );
4243
4344 std::string label;
4445 std::size_t num_compute_tasks;
45- if (compute_bench_type == ScalingType::StrongScaling) {
46+ if constexpr (scaling_type == ScalingType::StrongScaling) {
4647 num_compute_tasks = 1'0000 ;
47- label = " strong_scaling" ;
4848 } else {
4949 num_compute_tasks = 1000 * num_threads;
50- label = " weak_scaling" ;
5150 }
52- state. SetLabel (label);
51+
5352 std::size_t const num_compute_iterations{100'000 };
5453 kvikio::defaults::set_thread_pool_nthreads (num_threads);
5554
@@ -65,17 +64,22 @@ void BM_threadpool_compute(benchmark::State& state)
6564
6665 state.counters [" threads" ] = num_threads;
6766}
67+ } // namespace kvikio
6868
6969int main (int argc, char ** argv)
7070{
7171 benchmark::Initialize (&argc, argv);
7272
73- benchmark::RegisterBenchmark (" BM_threadpool_compute:strong_scaling" , BM_threadpool_compute)
74- ->ArgsProduct ({{1 , 2 , 4 , 8 , 16 , 32 , 64 }, {ScalingType::StrongScaling}})
73+ benchmark::RegisterBenchmark (" BM_threadpool_compute:strong_scaling" ,
74+ kvikio::BM_threadpool_compute<kvikio::ScalingType::StrongScaling>)
75+ ->RangeMultiplier (2 )
76+ ->Range (1 , 64 )
7577 ->Unit (benchmark::kMillisecond );
7678
77- benchmark::RegisterBenchmark (" BM_threadpool_compute:weak_scaling" , BM_threadpool_compute)
78- ->ArgsProduct ({{1 , 2 , 4 , 8 , 16 , 32 , 64 }, {ScalingType::WeakScaling}})
79+ benchmark::RegisterBenchmark (" BM_threadpool_compute:weak_scaling" ,
80+ kvikio::BM_threadpool_compute<kvikio::ScalingType::WeakScaling>)
81+ ->RangeMultiplier (2 )
82+ ->Range (1 , 64 )
7983 ->Unit (benchmark::kMillisecond );
8084
8185 benchmark::RunSpecifiedBenchmarks ();
0 commit comments