@@ -596,85 +596,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dtype_in = x.dtype
         dtype_w = self.gate_proj.weight.dtype

-        # Allocate workspace: needed size = batch_seq * hidden * sizeof(half).
-        # We currently use a half-precision workspace for all kernels (internal precision).
-        workspace_size = batch_seq * hidden * 2
+        # Allocate workspace if needed (currently unused, but kept for ABI compatibility)
+        workspace = torch.empty(0, device=x.device)
+        output = torch.empty_like(x_flat)
+        stream = torch.cuda.current_stream().cuda_stream

         # Use specific implementation based on types
         # 1. FP16 (Standard)
         if dtype_in == torch.float16 and dtype_w == torch.float16:
-            # Allocate workspace
-            workspace = torch.empty(batch_seq, hidden, dtype=torch.float16, device=x.device)
-            output = torch.empty_like(x_flat)
-
-            stream = torch.cuda.current_stream().cuda_stream
-
-            # Layout note: a PyTorch Linear weight is [out_features, in_features] and
-            # row-major, so W_gate is [INTER, HIDDEN] with element (i, k) at offset
-            # i*HIDDEN + k. cuBLAS is column-major (A[i + j*LDA]), so a [HIDDEN, INTER]
-            # matrix places element (row=k, col=i) at offset k + i*HIDDEN. The offsets
-            # are equal, so the row-major [INTER, HIDDEN] buffer can be passed directly
-            # as a column-major [HIDDEN, INTER] matrix: no transpose copy is needed.
-
-            _transformer_ops_lib.fused_mlp_block_launcher_fp16(
-                ctypes.c_void_p(x_flat.data_ptr()),
-                ctypes.c_void_p(self.norm_weight.data_ptr()),
-                ctypes.c_void_p(self.gate_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.up_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.down_proj.weight.data_ptr()),
-                ctypes.c_void_p(output.data_ptr()),
-                ctypes.c_void_p(workspace.data_ptr()),
-                ctypes.c_int(batch_seq),
-                ctypes.c_int(self.hidden_size),
-                ctypes.c_int(self.intermediate_size),
-                ctypes.c_float(self.eps),
-                ctypes.c_void_p(stream)
+            _transformer_ops_lib.fused_mlp_norm_gemm_launcher_fp16(
+                ctypes.c_void_p(x_flat.data_ptr()),
+                ctypes.c_void_p(self.norm_weight.data_ptr()),
+                ctypes.c_void_p(self.gate_proj.weight.data_ptr()),
+                ctypes.c_void_p(self.up_proj.weight.data_ptr()),
+                ctypes.c_void_p(self.down_proj.weight.data_ptr()),
+                ctypes.c_void_p(output.data_ptr()),
+                ctypes.c_void_p(workspace.data_ptr()),
+                ctypes.c_int(batch_seq),
+                ctypes.c_int(self.hidden_size),
+                ctypes.c_int(self.intermediate_size),
+                ctypes.c_float(self.eps),
+                ctypes.c_void_p(stream)
             )

         # 2. FP32 (Float)
         elif dtype_in == torch.float32 and dtype_w == torch.float32:
-            # The fp32 launcher signature takes float* workspace, so allocate a float32
-            # buffer here even though the kernel internally uses half for shared memory.
-            workspace = torch.empty(batch_seq, hidden, dtype=torch.float32, device=x.device)
-            output = torch.empty_like(x_flat)
-            stream = torch.cuda.current_stream().cuda_stream
-
-            _transformer_ops_lib.fused_mlp_block_launcher_fp32(
-                ctypes.c_void_p(x_flat.data_ptr()),
-                ctypes.c_void_p(self.norm_weight.data_ptr()),
-                ctypes.c_void_p(self.gate_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.up_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.down_proj.weight.data_ptr()),
-                ctypes.c_void_p(output.data_ptr()),
-                ctypes.c_void_p(workspace.data_ptr()),  # passed as float*
-                ctypes.c_int(batch_seq),
-                ctypes.c_int(self.hidden_size),
-                ctypes.c_int(self.intermediate_size),
-                ctypes.c_float(self.eps),
-                ctypes.c_void_p(stream)
-            )
+            # TODO: implement an FP32 version of fused_mlp_norm_gemm if needed.
+            # For now, fall back to the PyTorch implementation.
+            return self._pytorch_fallback(x)

         # 3. W8A16 (FP16 input, uint8 weight)
         elif dtype_in == torch.float16 and dtype_w == torch.uint8:
-            workspace = torch.empty(batch_seq, hidden, dtype=torch.float16, device=x.device)
-            output = torch.empty_like(x_flat)
-            stream = torch.cuda.current_stream().cuda_stream
+            # TODO: implement the W8A16 version.
+            return self._pytorch_fallback(x)

-            _transformer_ops_lib.fused_mlp_block_launcher_w8a16(
-                ctypes.c_void_p(x_flat.data_ptr()),
-                ctypes.c_void_p(self.norm_weight.data_ptr()),
-                ctypes.c_void_p(self.gate_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.up_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.down_proj.weight.data_ptr()),
-                ctypes.c_void_p(output.data_ptr()),
-                ctypes.c_void_p(workspace.data_ptr()),
-                ctypes.c_int(batch_seq),
-                ctypes.c_int(self.hidden_size),
-                ctypes.c_int(self.intermediate_size),
-                ctypes.c_float(self.eps),
-                ctypes.c_void_p(stream)
-            )
-
         else:
             # Mismatched or unsupported dtypes -> fall back
             return self._pytorch_fallback(x)
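
The layout reasoning in the removed comment can be verified numerically: interpreting the row-major [INTER, HIDDEN] weight buffer as a column-major [HIDDEN, INTER] matrix maps the same element to the same flat offset, which is why the PyTorch pointer is handed to the kernel without a transpose copy. A minimal sketch of that check (sizes are illustrative):

    import torch

    INTER, HIDDEN = 4, 3
    W = torch.arange(INTER * HIDDEN).reshape(INTER, HIDDEN)  # row-major, like nn.Linear.weight
    flat = W.flatten()  # the raw buffer the CUDA side sees

    for i in range(INTER):       # column index in the col-major [HIDDEN, INTER] view
        for k in range(HIDDEN):  # row index in the col-major view
            # row-major (i, k) -> offset i*HIDDEN + k; col-major (k, i) -> offset k + i*HIDDEN
            assert flat[k + i * HIDDEN] == W[i, k]  # identical offsets, so no copy is needed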
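_pytorch_fallback is not shown in this hunk; judging from the module's attributes (norm_weight, eps, and the gate/up/down projections), it presumably runs the unfused RMSNorm + SiLU-gated MLP in plain PyTorch. A hedged sketch of what such a fallback typically looks like, not the actual implementation:

    import torch
    import torch.nn.functional as F

    def _pytorch_fallback(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm (computed in fp32 for stability, then cast back; an assumption)
        h = x.float()
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.eps)
        h = (h * self.norm_weight.float()).to(x.dtype)
        # SiLU-gated MLP: the unfused equivalent of the fused CUDA kernel
        return self.down_proj(F.silu(self.gate_proj(h)) * self.up_proj(h))

Note that for the W8A16 branch a real fallback would additionally need to dequantize the uint8 weights before the projections can run.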