@@ -549,6 +549,9 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     ActivationType mActType = ActivationType::Relu;
 
     constexpr static int64_t NUM_BUFFERS = 32;
+    int64_t mNumWorkspaceBuffers = NUM_BUFFERS;
+    int64_t mNumInputBuffers = NUM_BUFFERS;
+    int64_t mNumGemmProfilerBuffers = NUM_BUFFERS;
 
     std::array<QuantParams, NUM_BUFFERS> mQuantParams{};
     bool mUseLora = false;
@@ -619,12 +622,12 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
         if (gemm_to_profile == GemmToProfile::LAYER)
         {
-
             mWorkspaceSize = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK,
                 mActType, parallelism_config, mUseLora, /*use_deepseek_fp8_block_scale=*/false,
                 /*min_latency_mode=*/false, mUsePrequantScale);
 
-            mWorkspace = allocBuffer<char>(mWorkspaceSize * NUM_BUFFERS);
+            mNumWorkspaceBuffers = mWorkspaceSize > 1024 * 1024 * 1024 ? 2 : NUM_BUFFERS;
+            mWorkspace = allocBuffer<char>(mWorkspaceSize * mNumWorkspaceBuffers);
 
             mExpertBias1 = nullptr;
             mExpertBias2 = nullptr;
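The new mNumWorkspaceBuffers member caps how many rotating copies of the workspace get allocated once a single workspace grows past 2^30 elements; the same size-gated cap is applied below to the input/output tensors and the GEMM profiler workspace. A minimal sketch of the heuristic in isolation, with hypothetical helper and constant names (the diff itself writes the ternary inline):

#include <cstdint>

// Keep 32 rotating buffers for small allocations, but only 2 once a single
// buffer exceeds 2^30 elements, so the total footprint stays bounded.
constexpr int64_t kMaxBuffers = 32; // mirrors NUM_BUFFERS
constexpr int64_t kLargeBufferThreshold = 1024LL * 1024 * 1024;

int64_t chooseBufferCount(int64_t bufferSize)
{
    return bufferSize > kLargeBufferThreshold ? 2 : kMaxBuffers;
}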
@@ -690,9 +693,10 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         mScaleProbsSize = padSize(mTotalTokens * mK);
         mScaleProbs = allocBuffer<float>(mScaleProbsSize * NUM_BUFFERS);
         mInputTensorSize = padSize(mTotalTokens * mHiddenSize);
-        mInputTensor = allocBuffer<DataType>(mInputTensorSize * NUM_BUFFERS);
+        mNumInputBuffers = mInputTensorSize > 1024 * 1024 * 1024 ? 2 : NUM_BUFFERS;
+        mInputTensor = allocBuffer<DataType>(mInputTensorSize * mNumInputBuffers);
         mFinalOutputSize = padSize(mTotalTokens * mHiddenSize);
-        mFinalOutput = allocBuffer<OutputType>(mFinalOutputSize * NUM_BUFFERS);
+        mFinalOutput = allocBuffer<OutputType>(mFinalOutputSize * mNumInputBuffers);
 
         mSourceToExpandedMapSize = padSize(mTotalTokens * mK);
         mSourceToExpandedMap = allocBuffer<int>(mSourceToExpandedMapSize * NUM_BUFFERS);
@@ -732,10 +736,11 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
                 = std::max(mGemmProfilerWorkspaceSize, mGemmProfilerBackend.getWorkspaceSize(mTotalTokens));
         }
 
-        int64_t num_gemm_buffers = gemm_to_profile == GemmToProfile::LAYER ? 1 : NUM_BUFFERS;
         mGemmProfilerWorkspaceSize = padSize(mGemmProfilerWorkspaceSize);
+        mNumGemmProfilerBuffers = mGemmProfilerWorkspaceSize > 1024 * 1024 * 1024 ? 2 : NUM_BUFFERS;
+        mNumGemmProfilerBuffers = gemm_to_profile == GemmToProfile::LAYER ? 1 : mNumGemmProfilerBuffers;
         mGemmProfilerWorkspace = mGemmProfilerWorkspaceSize > 0
-            ? allocBuffer<char>(mGemmProfilerWorkspaceSize * num_gemm_buffers)
+            ? allocBuffer<char>(mGemmProfilerWorkspaceSize * mNumGemmProfilerBuffers)
             : nullptr;
 
         check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
@@ -748,7 +753,8 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         mGemmProfilerBackend.mGemmToProfile = static_cast<GemmProfilerBackend::GemmToProfile>(gemm_to_profile);
         auto* expert_weights = gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1 : mExpertWeight2;
         auto expert_weights_size = gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size;
-        mGemmProfilerBackend.prepare(mTotalTokens, mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * mBufferIndex,
+        mGemmProfilerBackend.prepare(mTotalTokens,
+            mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * (mBufferIndex % mNumGemmProfilerBuffers),
             /*expert_weights=*/expert_weights + expert_weights_size * mBufferIndex, streamPtr->get());
     }
 
@@ -865,7 +871,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         }
 
         // Profile all samples or for 1 sec
-        int const max_iters = mGemmProfilerBackend.NUM_ROUTING_SAMPLES;
+        int const max_iters = mGemmProfilerBackend.NUM_ROUTING_SAMPLES * 2;
         float const max_time_ms = 1000.f;
 
         float time = 0.f;
@@ -974,7 +980,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             }
             mGemmProfilerBackend.mSampleIndex = mBufferIndex % mGemmProfilerBackend.NUM_ROUTING_SAMPLES;
             mGemmProfilerBackend.runProfiler(mTotalTokens, tactics,
-                mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * mBufferIndex,
+                mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * (mBufferIndex % mNumGemmProfilerBuffers),
                 /*expert_weights=*/expert_weights + expert_weights_size * mBufferIndex, streamPtr->get());
             break;
         }
@@ -983,26 +989,28 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             auto stream = streamPtr->get();
             MoeMinLatencyParams min_latency_params;
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-            mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
+            mMoERunner.runMoe(mInputTensor + mInputTensorSize * (mBufferIndex % mNumInputBuffers), nullptr, true,
                 mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
                 mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
                 mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
                 ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
                 mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-                mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
-                mFinalOutput + mFinalOutputSize * mBufferIndex,
+                mHiddenSize, mInterSize, mNumExperts, mK,
+                mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
+                mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
                 mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
                 /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
                 /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
 #else
-            mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
+            mMoERunner.runMoe(mInputTensor + mInputTensorSize * (mBufferIndex % mNumInputBuffers), nullptr, true,
                 mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
                 mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
                 mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
                 ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
                 mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-                mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
-                mFinalOutput + mFinalOutputSize * mBufferIndex,
+                mHiddenSize, mInterSize, mNumExperts, mK,
+                mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
+                mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
                 mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
                 /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
                 /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
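Because mBufferIndex keeps cycling over all NUM_BUFFERS slots, every pointer built from one of the capped pools wraps the index with a modulo before scaling by the per-buffer size, as the updated prepare, runProfiler and runMoe calls above do. A minimal sketch of that offset computation, with hypothetical names:

#include <cstdint>

// Round-robin offset into a pool that may hold fewer slots than the global
// buffer index reaches: wrap the index first, then scale by the slot size.
int64_t bufferOffset(int64_t bufferIndex, int64_t numBuffers, int64_t bufferSize)
{
    return bufferSize * (bufferIndex % numBuffers);
}

// Example: char* ws = mWorkspace + bufferOffset(mBufferIndex, mNumWorkspaceBuffers, mWorkspaceSize);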