@@ -69,6 +69,7 @@ def scaled_dot_product_attention(
69
69
is_causal = True
70
70
# implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
71
71
use_fp32_acc = kwargs .get ("use_fp32_acc" , False )
72
+ use_fp8_quantize = kwargs .get ("use_fp8_quantize" , True )
72
73
query_dtype = query .dtype
73
74
74
75
if scale is None :
@@ -97,6 +98,30 @@ def scaled_dot_product_attention(
97
98
key ,
98
99
scale ,
99
100
)
101
+ # TODO(review): hard-coded amax placeholder (0.6562) used for ALL quantize calls below — replace with a calibrated/computed per-tensor amax before merging; also note use_fp8_quantize defaults to True, so this test value is active by default
102
+ amax = torch .tensor ([0.6562 ])
103
+ if use_fp8_quantize :
104
+ key = impl .quantize .quantize (
105
+ ctx ,
106
+ target ,
107
+ SourceIR .ATEN ,
108
+ name ,
109
+ key ,
110
+ amax ,
111
+ 8 ,
112
+ 4 ,
113
+ )
114
+
115
+ query = impl .quantize .quantize (
116
+ ctx ,
117
+ target ,
118
+ SourceIR .ATEN ,
119
+ name ,
120
+ query ,
121
+ amax ,
122
+ 8 ,
123
+ 4 ,
124
+ )
100
125
101
126
if use_fp32_acc and query_dtype == trt .float16 :
102
127
query = cast_trt_tensor (
@@ -173,6 +198,29 @@ def scaled_dot_product_attention(
173
198
softmax = impl .normalization .softmax (
174
199
ctx , target , source_ir , name + "_softmax" , scaled_add_attn_bias , - 1 , False
175
200
)
201
+ if use_fp8_quantize :
202
+ softmax = impl .quantize .quantize (
203
+ ctx ,
204
+ target ,
205
+ SourceIR .ATEN ,
206
+ name ,
207
+ softmax ,
208
+ amax ,
209
+ 8 ,
210
+ 4 ,
211
+ )
212
+
213
+ value = impl .quantize .quantize (
214
+ ctx ,
215
+ target ,
216
+ SourceIR .ATEN ,
217
+ name ,
218
+ value ,
219
+ amax ,
220
+ 8 ,
221
+ 4 ,
222
+ )
223
+
176
224
if use_fp32_acc :
177
225
softmax = cast_trt_tensor (
178
226
ctx , softmax , trt .float32 , name + "_softmax_cast_to_fp32" , target , source_ir
@@ -188,9 +236,21 @@ def scaled_dot_product_attention(
188
236
softmax ,
189
237
value ,
190
238
)
239
+
191
240
if use_fp32_acc :
192
241
out = cast_trt_tensor (
193
242
ctx , out , query_dtype , name + "_out_cast_to_fp16" , target , source_ir
194
243
)
244
+ if use_fp8_quantize :
245
+ out = impl .quantize .quantize (
246
+ ctx ,
247
+ target ,
248
+ SourceIR .ATEN ,
249
+ name ,
250
+ out ,
251
+ amax ,
252
+ 8 ,
253
+ 4 ,
254
+ )
195
255
196
256
return out
0 commit comments