-
Notifications
You must be signed in to change notification settings - Fork 76
Description
Describe the issue
w8a8_block_fp8_matmul_triton.py
Run the file above as follows:
python w8a8_block_fp8_matmul_triton.py
Use an XPU PyTorch 2.9.0 environment; pytorch-triton-xpu 3.5.0 is the default Triton version for it.
Changing only pytorch-triton-xpu to 3.4.0 (everything else unchanged) fixes the performance regression.
The regression should not have to be fixed by switching to tensor descriptors. The same code should run with similar performance on all versions.
Results on a BMG B580 12GB card
With pytorch-triton-xpu 3.5.0:
running main
running 48 shapes for 10 times
i=0, avg time - 9.7769412 ms, for input [7, 7168], weight [2112, 7168]
i=1, avg time - 7.273006 ms, for input [7, 1536], weight [24576, 1536]
i=2, avg time - 2.6029328 ms, for input [7, 512], weight [32768, 512]
i=3, avg time - 30.016615199999997 ms, for input [7, 16384], weight [7168, 16384]
i=4, avg time - 38.2207904 ms, for input [7, 7168], weight [36864, 7168]
i=5, avg time - 33.846961199999996 ms, for input [7, 18432], weight [7168, 18432]
i=6, avg time - 10.565994399999997 ms, for input [7, 7168], weight [4096, 7168]
i=7, avg time - 3.6795356000000004 ms, for input [7, 2048], weight [7168, 2048]
i=8, avg time - 70.88357640000001 ms, for input [2048, 7168], weight [2112, 7168]
i=9, avg time - 156.15292680000002 ms, for input [2048, 1536], weight [24576, 1536]
i=10, avg time - 71.05428199999999 ms, for input [2048, 512], weight [32768, 512]
i=11, avg time - 493.77448639999994 ms, for input [2048, 16384], weight [7168, 16384]
i=12, avg time - 1082.5891648000002 ms, for input [2048, 7168], weight [36864, 7168]
i=13, avg time - 555.2376075999999 ms, for input [2048, 18432], weight [7168, 18432]
i=14, avg time - 125.94295480000001 ms, for input [2048, 7168], weight [4096, 7168]
i=15, avg time - 61.3912052 ms, for input [2048, 2048], weight [7168, 2048]
i=16, avg time - 72.8357812 ms, for input [2029, 7168], weight [2112, 7168]
i=17, avg time - 159.0087512 ms, for input [2029, 1536], weight [24576, 1536]
i=18, avg time - 72.3626436 ms, for input [2029, 512], weight [32768, 512]
i=19, avg time - 502.07314039999994 ms, for input [2029, 16384], weight [7168, 16384]
i=20, avg time - 1102.4178672 ms, for input [2029, 7168], weight [36864, 7168]
i=21, avg time - 564.2669812000001 ms, for input [2029, 18432], weight [7168, 18432]
i=22, avg time - 129.5504548 ms, for input [2029, 7168], weight [4096, 7168]
i=23, avg time - 62.6102412 ms, for input [2029, 2048], weight [7168, 2048]
i=24, avg time - 2.8184468 ms, for input [1, 7168], weight [2112, 7168]
i=25, avg time - 2.3672895999999994 ms, for input [1, 1536], weight [24576, 1536]
i=26, avg time - 0.9661028 ms, for input [1, 512], weight [32768, 512]
i=27, avg time - 9.3407808 ms, for input [1, 16384], weight [7168, 16384]
i=28, avg time - 12.0774524 ms, for input [1, 7168], weight [36864, 7168]
i=29, avg time - 10.8239404 ms, for input [1, 18432], weight [7168, 18432]
i=30, avg time - 3.3447128 ms, for input [1, 7168], weight [4096, 7168]
i=31, avg time - 1.2411151999999999 ms, for input [1, 2048], weight [7168, 2048]
i=32, avg time - 8.798514399999998 ms, for input [16, 7168], weight [2112, 7168]
i=33, avg time - 7.170436 ms, for input [16, 1536], weight [24576, 1536]
i=34, avg time - 2.56412 ms, for input [16, 512], weight [32768, 512]
i=35, avg time - 29.543997599999994 ms, for input [16, 16384], weight [7168, 16384]
i=36, avg time - 37.246804399999995 ms, for input [16, 7168], weight [36864, 7168]
i=37, avg time - 32.767134399999996 ms, for input [16, 18432], weight [7168, 18432]
i=38, avg time - 10.500547199999998 ms, for input [16, 7168], weight [4096, 7168]
i=39, avg time - 3.6867636 ms, for input [16, 2048], weight [7168, 2048]
i=40, avg time - 36.7310268 ms, for input [1024, 7168], weight [2112, 7168]
i=41, avg time - 78.567944 ms, for input [1024, 1536], weight [24576, 1536]
i=42, avg time - 35.7436144 ms, for input [1024, 512], weight [32768, 512]
i=43, avg time - 253.3582012 ms, for input [1024, 16384], weight [7168, 16384]
i=44, avg time - 545.1474912 ms, for input [1024, 7168], weight [36864, 7168]
i=45, avg time - 284.55867440000003 ms, for input [1024, 18432], weight [7168, 18432]
i=46, avg time - 67.59068679999999 ms, for input [1024, 7168], weight [4096, 7168]
i=47, avg time - 31.732818 ms, for input [1024, 2048], weight [7168, 2048]
With pytorch-triton-xpu 3.4.0:
running main
running 48 shapes for 10 times
i=0, avg time - 3.6866492 ms, for input [7, 7168], weight [2112, 7168]
i=1, avg time - 2.6210964 ms, for input [7, 1536], weight [24576, 1536]
i=2, avg time - 1.0747048 ms, for input [7, 512], weight [32768, 512]
i=3, avg time - 12.1522544 ms, for input [7, 16384], weight [7168, 16384]
i=4, avg time - 14.6122028 ms, for input [7, 7168], weight [36864, 7168]
i=5, avg time - 9.622121600000002 ms, for input [7, 18432], weight [7168, 18432]
i=6, avg time - 3.3113808 ms, for input [7, 7168], weight [4096, 7168]
i=7, avg time - 1.1259092000000002 ms, for input [7, 2048], weight [7168, 2048]
i=8, avg time - 27.478443199999997 ms, for input [2048, 7168], weight [2112, 7168]
i=9, avg time - 64.0359096 ms, for input [2048, 1536], weight [24576, 1536]
i=10, avg time - 29.911211199999997 ms, for input [2048, 512], weight [32768, 512]
i=11, avg time - 202.36799959999996 ms, for input [2048, 16384], weight [7168, 16384]
i=12, avg time - 436.147114 ms, for input [2048, 7168], weight [36864, 7168]
i=13, avg time - 220.9043876 ms, for input [2048, 18432], weight [7168, 18432]
i=14, avg time - 50.13476 ms, for input [2048, 7168], weight [4096, 7168]
i=15, avg time - 24.426875199999998 ms, for input [2048, 2048], weight [7168, 2048]
i=16, avg time - 27.534993199999995 ms, for input [2029, 7168], weight [2112, 7168]
i=17, avg time - 63.8328652 ms, for input [2029, 1536], weight [24576, 1536]
i=18, avg time - 30.010848400000004 ms, for input [2029, 512], weight [32768, 512]
i=19, avg time - 203.45114919999997 ms, for input [2029, 16384], weight [7168, 16384]
i=20, avg time - 436.2289723999999 ms, for input [2029, 7168], weight [36864, 7168]
i=21, avg time - 221.38158120000003 ms, for input [2029, 18432], weight [7168, 18432]
i=22, avg time - 49.808824 ms, for input [2029, 7168], weight [4096, 7168]
i=23, avg time - 24.528337599999997 ms, for input [2029, 2048], weight [7168, 2048]
i=24, avg time - 3.1782608 ms, for input [1, 7168], weight [2112, 7168]
i=25, avg time - 1.7648435999999996 ms, for input [1, 1536], weight [24576, 1536]
i=26, avg time - 0.7653931999999999 ms, for input [1, 512], weight [32768, 512]
i=27, avg time - 6.4618684 ms, for input [1, 16384], weight [7168, 16384]
i=28, avg time - 10.9278364 ms, for input [1, 7168], weight [36864, 7168]
i=29, avg time - 6.9387968 ms, for input [1, 18432], weight [7168, 18432]
i=30, avg time - 2.4846172 ms, for input [1, 7168], weight [4096, 7168]
i=31, avg time - 0.8120631999999999 ms, for input [1, 2048], weight [7168, 2048]
i=32, avg time - 3.0808336 ms, for input [16, 7168], weight [2112, 7168]
i=33, avg time - 2.6114816 ms, for input [16, 1536], weight [24576, 1536]
i=34, avg time - 1.0760672 ms, for input [16, 512], weight [32768, 512]
i=35, avg time - 8.575273200000002 ms, for input [16, 16384], weight [7168, 16384]
i=36, avg time - 15.717702 ms, for input [16, 7168], weight [36864, 7168]
i=37, avg time - 9.6841316 ms, for input [16, 18432], weight [7168, 18432]
i=38, avg time - 3.3002684 ms, for input [16, 7168], weight [4096, 7168]
i=39, avg time - 1.1416548 ms, for input [16, 2048], weight [7168, 2048]
i=40, avg time - 14.5202772 ms, for input [1024, 7168], weight [2112, 7168]
i=41, avg time - 31.759488799999996 ms, for input [1024, 1536], weight [24576, 1536]
i=42, avg time - 14.884422800000001 ms, for input [1024, 512], weight [32768, 512]
i=43, avg time - 103.747436 ms, for input [1024, 16384], weight [7168, 16384]
i=44, avg time - 218.39304759999996 ms, for input [1024, 7168], weight [36864, 7168]
i=45, avg time - 113.50453399999999 ms, for input [1024, 18432], weight [7168, 18432]
i=46, avg time - 25.662197599999995 ms, for input [1024, 7168], weight [4096, 7168]
i=47, avg time - 12.524044 ms, for input [1024, 2048], weight [7168, 2048]
Environment details
Triton: 3.5.0
GPU: B580 12GB