@@ -68,50 +68,54 @@ class TrtllmGenBatchedGemmRunner
6868 int32_t configIndex) const ;
6969
7070 // Generic GEMM interface
71- void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, int32_t numTokens,
72- int32_t numBatches, int32_t maxNumCtasInBatchDim, void const * a, void const * sfA, void const * b,
73- void const * sfB, void const * perTokensSfA, void const * perTokensSfB, float const * scaleC,
74- float const * scaleGateC, float const * bias, float const * swiGluAlpha, float const * swiGluBeta,
75- float const * clampLimit, void * c, void * outSfC, int32_t const * routeMap, int32_t const * totalNumPaddedTokens,
76- int32_t const * ctaIdxXyToBatchIdx, int32_t const * ctaIdxXyToMnLimit, int32_t const * numNonExitingCtas,
77- void * workspace, CUstream stream, int device, int32_t configIndex);
71+ void run (int32_t m, int32_t n, int32_t k, int32_t validM, int32_t validN, int32_t validK,
72+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
73+ void const * a, void const * sfA, void const * b, void const * sfB, void const * perTokensSfA,
74+ void const * perTokensSfB, float const * scaleC, float const * scaleGateC, float const * bias,
75+ float const * swiGluAlpha, float const * swiGluBeta, float const * clampLimit, void * c, void * outSfC,
76+ int32_t const * routeMap, int32_t const * totalNumPaddedTokens, int32_t const * ctaIdxXyToBatchIdx,
77+ int32_t const * ctaIdxXyToMnLimit, int32_t const * numNonExitingCtas, void * workspace, CUstream stream,
78+ int device, int32_t configIndex);
7879
7980 // Block-scaling GEMM
8081 void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, void const * a, void const * sfA,
8182 void const * b, void const * sfB, void * c, void * outSfC, void * workspace, CUstream stream, int device,
82- int32_t configIndex);
83+ int32_t configIndex, int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 );
8384
8485 // Block-scaling GEMM with SwiGLU activation
8586 void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, void const * a, void const * sfA,
8687 void const * b, void const * sfB, float const * bias, float const * swiGluAlpha, float const * swiGluBeta,
8788 float const * clampLimit, void * c, void * outSfC, void * workspace, CUstream stream, int device,
88- int32_t configIndex);
89+ int32_t configIndex, int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 );
8990
9091 // FP8 per-tensor scaling GEMM
9192 void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, void const * a, void const * b,
9293 float const * scaleC, float const * scaleGateC, void * c, void * workspace, CUstream stream, int device,
93- int32_t configIndex);
94+ int32_t configIndex, int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 );
9495
9596 // Get the list of configs that passed the validation based on the constructor options
9697 [[nodiscard]] std::vector<int64_t > getPassingConfigIndices () const
9798 {
9899 return mPassingConfigIndices ;
99100 }
100101
102+ // Get the kernel name from the config index
103+ [[nodiscard]] std::string getKernelNameFromConfigIndex (int32_t configIndex) const ;
104+
101105 // Get the list of config indices that are valid for the given problem shape
102106 [[nodiscard]] std::vector<int64_t > getValidConfigIndices (int32_t m, int32_t n, int32_t k,
103- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
104- int32_t maxNumCtasInBatchDim ) const ;
107+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
108+ int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 ) const ;
105109
106110 // Get a default config index that is valid for the given problem shape
107111 // This will be used as the fallback config if using auto-tuning
108112 [[nodiscard]] int64_t getDefaultValidConfigIndex (int32_t m, int32_t n, int32_t k,
109- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
110- int32_t maxNumCtasInBatchDim ) const ;
113+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
114+ int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 ) const ;
111115
112116 [[nodiscard]] bool isValidConfigIndex (int32_t configIndex, int32_t m, int32_t n, int32_t k,
113- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
114- int32_t maxNumCtasInBatchDim ) const ;
117+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
118+ int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 ) const ;
115119
116120private:
117121 void selectGemmConfig (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, int32_t numTokens,
0 commit comments