Commit 120cd3a

Refactor fused MLP block to use new CUDA kernel
Replaces the FP16 fused MLP block launcher with a new fused_mlp_norm_gemm kernel that integrates RMSNorm and GEMM for improved efficiency. Updates the CUDA implementation to fuse normalization and matrix multiplications, adds vectorized memory operations, and improves shared memory usage. FP32 and W8A16 paths now fall back to PyTorch. Removes test_triton_ops.py from the test runner.
1 parent e59f08e commit 120cd3a

3 files changed

Lines changed: 224 additions & 123 deletions
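
For orientation, the block the new fused_mlp_norm_gemm kernel covers is an RMSNorm followed by a gated MLP over the gate/up/down projections, which is also what the FP32 and W8A16 paths now compute through the PyTorch fallback. Below is a minimal PyTorch sketch of that computation, assuming a SiLU (SwiGLU) gating as implied by the norm_weight, gate_proj, up_proj, and down_proj names; the actual _pytorch_fallback implementation and the kernel's activation choice are not shown in this commit, so treat the function name and details as illustrative.

import torch
import torch.nn.functional as F

def mlp_block_reference(x: torch.Tensor,
                        norm_weight: torch.Tensor,   # [hidden]
                        gate_w: torch.Tensor,        # [intermediate, hidden]
                        up_w: torch.Tensor,          # [intermediate, hidden]
                        down_w: torch.Tensor,        # [hidden, intermediate]
                        eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm over the hidden dimension, accumulated in fp32 for stability.
    x32 = x.float()
    rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    h = (x32 * rms).to(x.dtype) * norm_weight
    # SiLU-gated MLP: down( silu(gate(h)) * up(h) )
    gate = F.silu(h @ gate_w.t())
    up = h @ up_w.t()
    return (gate * up) @ down_w.t()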

File tree

Src/Main_Scripts/core/cuda_opt_wrapper.py

Lines changed: 22 additions & 77 deletions
@@ -596,95 +596,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dtype_in = x.dtype
         dtype_w = self.gate_proj.weight.dtype
 
-        # Allocate workspace: needed size = batch_seq * hidden * sizeof(half)
-        # We use a half-precision workspace for all kernels currently (internal precision)
-        workspace_size = batch_seq * hidden * 2
+        # Allocate workspace if needed (unused currently but kept for ABI)
+        workspace = torch.empty(0, device=x.device)
+        output = torch.empty_like(x_flat)
+        stream = torch.cuda.current_stream().cuda_stream
 
         # Use specific implementation based on types
         # 1. FP16 (Standard)
         if dtype_in == torch.float16 and dtype_w == torch.float16:
-            # Alloc workspace
-            workspace = torch.empty(batch_seq, hidden, dtype=torch.float16, device=x.device)
-            output = torch.empty_like(x_flat)
-
-            stream = torch.cuda.current_stream().cuda_stream
-
-            # Weights are [d_out, d_in] (row-major in PyTorch) -> transpose to column-major for CUDA?
-            # The CUDA kernel expects column-major [HIDDEN, INTER]; PyTorch Linear weights are
-            # [out_features, in_features], so W_gate is [INTER, HIDDEN] row-major.
-            # Column-major means A[i + j*LDA], while PyTorch tensors are row-major in memory,
-            # so passing a PyTorch pointer to code that assumes column-major effectively transposes it.
-            # If CUDA reads the buffer as column-major [HIDDEN, INTER], element (row=k, col=i) sits at
-            # offset k + i*HIDDEN; the PyTorch [INTER, HIDDEN] element (row=i, col=k) sits at i*HIDDEN + k.
-            # These offsets match, so PyTorch [INTER, HIDDEN] row-major == CUDA [HIDDEN, INTER] column-major.
-            # Correct: no transpose or copy is needed.
-
-            _transformer_ops_lib.fused_mlp_block_launcher_fp16(
-                ctypes.c_void_p(x_flat.data_ptr()),
-                ctypes.c_void_p(self.norm_weight.data_ptr()),
-                ctypes.c_void_p(self.gate_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.up_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.down_proj.weight.data_ptr()),
-                ctypes.c_void_p(output.data_ptr()),
-                ctypes.c_void_p(workspace.data_ptr()),
-                ctypes.c_int(batch_seq),
-                ctypes.c_int(self.hidden_size),
-                ctypes.c_int(self.intermediate_size),
-                ctypes.c_float(self.eps),
-                ctypes.c_void_p(stream)
+            _transformer_ops_lib.fused_mlp_norm_gemm_launcher_fp16(
+                ctypes.c_void_p(x_flat.data_ptr()),
+                ctypes.c_void_p(self.norm_weight.data.data_ptr()),
+                ctypes.c_void_p(self.gate_proj.weight.data.data_ptr()),
+                ctypes.c_void_p(self.up_proj.weight.data.data_ptr()),
+                ctypes.c_void_p(self.down_proj.weight.data.data_ptr()),
+                ctypes.c_void_p(output.data_ptr()),
+                ctypes.c_void_p(workspace.data_ptr()),
+                ctypes.c_int(batch_seq),
+                ctypes.c_int(self.hidden_size),
+                ctypes.c_int(self.intermediate_size),
+                ctypes.c_float(self.eps),
+                ctypes.c_void_p(stream)
             )
 
         # 2. FP32 (Float)
         elif dtype_in == torch.float32 and dtype_w == torch.float32:
-            # Allocate the workspace. The kernel uses half internally for shared memory,
-            # but the launcher signature takes a float* workspace, so allocate float here.
-            workspace = torch.empty(batch_seq, hidden, dtype=torch.float32, device=x.device)
-            output = torch.empty_like(x_flat)
-            stream = torch.cuda.current_stream().cuda_stream
-
-            _transformer_ops_lib.fused_mlp_block_launcher_fp32(
-                ctypes.c_void_p(x_flat.data_ptr()),
-                ctypes.c_void_p(self.norm_weight.data_ptr()),
-                ctypes.c_void_p(self.gate_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.up_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.down_proj.weight.data_ptr()),
-                ctypes.c_void_p(output.data_ptr()),
-                ctypes.c_void_p(workspace.data_ptr()),  # passed as float*
-                ctypes.c_int(batch_seq),
-                ctypes.c_int(self.hidden_size),
-                ctypes.c_int(self.intermediate_size),
-                ctypes.c_float(self.eps),
-                ctypes.c_void_p(stream)
-            )
+            # TODO: Implement an FP32 version of fused_mlp_norm_gemm if needed.
+            # For now, fall back or use the existing implementation.
+            return self._pytorch_fallback(x)
 
         # 3. W8A16 (FP16 Input, uint8 Weight)
         elif dtype_in == torch.float16 and dtype_w == torch.uint8:
-            workspace = torch.empty(batch_seq, hidden, dtype=torch.float16, device=x.device)
-            output = torch.empty_like(x_flat)
-            stream = torch.cuda.current_stream().cuda_stream
+            # TODO: Implement a W8A16 version.
+            return self._pytorch_fallback(x)
 
-            _transformer_ops_lib.fused_mlp_block_launcher_w8a16(
-                ctypes.c_void_p(x_flat.data_ptr()),
-                ctypes.c_void_p(self.norm_weight.data_ptr()),
-                ctypes.c_void_p(self.gate_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.up_proj.weight.data_ptr()),
-                ctypes.c_void_p(self.down_proj.weight.data_ptr()),
-                ctypes.c_void_p(output.data_ptr()),
-                ctypes.c_void_p(workspace.data_ptr()),
-                ctypes.c_int(batch_seq),
-                ctypes.c_int(self.hidden_size),
-                ctypes.c_int(self.intermediate_size),
-                ctypes.c_float(self.eps),
-                ctypes.c_void_p(stream)
-            )
-
         else:
             # Mismatch or unsupported -> Fallback
             return self._pytorch_fallback(x)
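
The comments removed from the FP16 branch reason that a PyTorch Linear weight of shape [INTER, HIDDEN] (row-major) can be handed directly to a kernel that reads a column-major [HIDDEN, INTER] matrix, because the flat offsets coincide. A small standalone check of that layout identity, illustrative only and not part of this commit:

import torch

INTER, HIDDEN = 8, 4

# A PyTorch-Linear-style weight: shape [INTER, HIDDEN], row-major in memory.
w = torch.arange(INTER * HIDDEN, dtype=torch.float32).reshape(INTER, HIDDEN)
flat = w.flatten()  # the raw buffer the kernel would see via data_ptr()

# Reinterpret the same buffer as a column-major [HIDDEN, INTER] matrix:
# element (k, i) is read from offset k + i*HIDDEN, with no copy involved.
col_major_view = flat.as_strided((HIDDEN, INTER), (1, HIDDEN))

# That view is exactly the transpose, so the kernel effectively receives W^T for free.
assert torch.equal(col_major_view, w.t())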
