@@ -554,11 +554,24 @@ def target_scaled_mm_prologue_pattern(
554554 )
555555
556556 def register_nvfp4_gemm_prologue (custom_pass : PatternMatcherPass ):
557+ act_fp4_key = KeywordArg ('act_fp4' )
558+ weight_key = KeywordArg ('weight' )
559+ act_sf_key = KeywordArg ('act_sf' )
560+ weight_scale_key = KeywordArg ('weight_scale' )
561+ alpha_key = KeywordArg ('alpha' )
562+ output_dtype_key = KeywordArg ('output_dtype' )
563+ to_userbuffers_key = KeywordArg ('to_userbuffers' )
564+ backend_key = KeywordArg ('backend' )
557565 trtllm_nvfp4_gemm_default = CallFunction (
558- torch .ops .trtllm .nvfp4_gemm .default , KeywordArg ('act_fp4' ),
559- KeywordArg ('weight' ), KeywordArg ('act_sf' ),
560- KeywordArg ('weight_scale' ), KeywordArg ('alpha' ),
561- KeywordArg ('output_dtype' ))
566+ torch .ops .trtllm .nvfp4_gemm .default ,
567+ act_fp4_key ,
568+ weight_key ,
569+ act_sf_key ,
570+ weight_scale_key ,
571+ alpha_key ,
572+ output_dtype_key ,
573+ to_userbuffers = to_userbuffers_key ,
574+ backend = backend_key )
562575 ub_copy = CallFunction (torch .ops .trtllm .copy_to_userbuffers ,
563576 trtllm_nvfp4_gemm_default )
564577
@@ -569,6 +582,8 @@ def empty_nvfp4_gemm_prologue_pattern(
569582 weight_scale : torch .Tensor ,
570583 alpha : torch .Tensor ,
571584 output_dtype : torch .dtype ,
585+ to_userbuffers : bool ,
586+ backend : str ,
572587 ):
573588 return
574589
@@ -579,18 +594,45 @@ def target_nvfp4_gemm_prologue_pattern(
579594 weight_scale : torch .Tensor ,
580595 alpha : torch .Tensor ,
581596 output_dtype : torch .dtype ,
597+ to_userbuffers : bool ,
598+ backend : str ,
582599 ):
583600 nvfp4_gemm_output = torch .ops .trtllm .nvfp4_gemm (
584601 act_fp4 , weight , act_sf , weight_scale , alpha , output_dtype ,
585- True )
602+ True , backend )
586603 return nvfp4_gemm_output
587604
588- # No extra check needed as the output dtype of nvfp4_gemm has been verified when
589- # ub_copy is inserted.
605+ def extra_check (match : Match ) -> bool :
606+ # Validate backend value
607+ backend_node = match .kwargs .get ('backend' )
608+ if backend_node is None :
609+ # No backend specified, use default - OK
610+ return True
611+
612+ valid_backends = {'auto' , 'cutlass' , 'cublaslt' , 'cutedsl' }
613+
614+ # Case 1: backend is a Node with metadata
615+ if hasattr (backend_node , 'meta' ) and 'val' in backend_node .meta :
616+ backend_value = backend_node .meta ['val' ]
617+ if isinstance (backend_value , str ):
618+ return backend_value in valid_backends
619+ return False # Invalid type
620+
621+ # Case 2: backend is a constant Node in the graph
622+ if hasattr (backend_node , 'target' ):
623+ return backend_node .target in valid_backends
624+
625+ # Case 3: backend is a Python literal in kwargs
626+ if isinstance (backend_node , str ):
627+ return backend_node in valid_backends
628+
629+ # Unknown format - reject to be safe
630+ return False
631+
590632 register_replacement (
591633 empty_nvfp4_gemm_prologue_pattern ,
592634 target_nvfp4_gemm_prologue_pattern ,
593- [],
635+ [extra_check ],
594636 fwd_only ,
595637 custom_pass ,
596638 search_fn_pattern = ub_copy ,
0 commit comments