@@ -217,7 +217,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization32)
217217 /* numExperts=*/ 32 , /* topK=*/ 8 ,
218218 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
219219 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
220- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
220+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
221221 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
222222 this ->runTest (param);
223223};
@@ -228,7 +228,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization72)
228228 /* numExperts=*/ 72 , /* topK=*/ 6 ,
229229 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
230230 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
231- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
231+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
232232 /* nGroup*/ 1 , /* topkGroup*/ 1 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
233233 this ->runTest (param);
234234};
@@ -239,7 +239,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384)
239239 /* numExperts=*/ 384 , /* topK=*/ 8 ,
240240 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
241241 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
242- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
242+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
243243 /* nGroup*/ 1 , /* topkGroup*/ 1 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
244244 this ->runTest (param);
245245};
@@ -250,7 +250,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
250250 /* numExperts=*/ 256 , /* topK=*/ 8 ,
251251 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
252252 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
253- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
253+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
254254 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
255255 this ->runTest (param);
256256};
@@ -261,7 +261,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
261261 /* numExperts=*/ 256 , /* topK=*/ 8 ,
262262 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 192 ,
263263 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
264- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
264+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput= */ true ,
265265 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
266266 this ->runTest (param);
267267};
@@ -272,7 +272,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
272272 /* numExperts=*/ 384 , /* topK=*/ 8 ,
273273 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
274274 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
275- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
275+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput= */ false ,
276276 /* nGroup*/ 1 , /* topkGroup*/ 1 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
277277 this ->runTest (param);
278278};
@@ -283,7 +283,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
283283 /* numExperts=*/ 256 , /* topK=*/ 8 ,
284284 /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 192 ,
285285 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
286- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
286+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
287287 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
288288 this ->runTest (param);
289289};
@@ -294,7 +294,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
294294 /* numExperts=*/ 256 , /* topK=*/ 8 ,
295295 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
296296 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
297- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
297+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
298298 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 10 );
299299 this ->runTest (param);
300300};
@@ -305,7 +305,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384)
305305 /* numExperts=*/ 384 , /* topK=*/ 8 ,
306306 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
307307 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
308- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
308+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
309309 /* nGroup*/ 1 , /* topkGroup*/ 1 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 10 );
310310 this ->runTest (param);
311311};
@@ -316,7 +316,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
316316 /* numExperts=*/ 256 , /* topK=*/ 8 ,
317317 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
318318 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
319- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
319+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput= */ true ,
320320 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 10 );
321321 this ->runTest (param);
322322};
@@ -327,7 +327,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384)
327327 /* numExperts=*/ 384 , /* topK=*/ 8 ,
328328 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
329329 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
330- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
330+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
331331 /* nGroup*/ 1 , /* topkGroup*/ 1 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 10 );
332332 this ->runTest (param);
333333};
@@ -338,7 +338,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
338338 /* numExperts=*/ 256 , /* topK=*/ 2 ,
339339 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
340340 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
341- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
341+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
342342 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
343343 this ->runTest (param);
344344};
@@ -349,7 +349,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
349349 /* numExperts=*/ 256 , /* topK=*/ 2 ,
350350 /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 192 ,
351351 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
352- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
352+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
353353 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
354354 this ->runTest (param);
355355};
@@ -360,7 +360,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2)
360360 /* numExperts=*/ 256 , /* topK=*/ 2 ,
361361 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
362362 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
363- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
363+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
364364 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 10 );
365365 this ->runTest (param);
366366};
@@ -371,7 +371,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop8)
371371 /* numExperts=*/ 32 , /* topK=*/ 8 ,
372372 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
373373 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
374- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
374+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput= */ true ,
375375 /* nGroup*/ 8 , /* topkGroup*/ 4 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 10 );
376376 this ->runTest (param);
377377};
0 commit comments