Skip to content

Commit 62cb0f2

Browse files
committed
chore: fp8 quantization in MHA
1 parent df34e8b commit 62cb0f2

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

tools/llm/torchtrt_ext/sdpa_converter.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def scaled_dot_product_attention(
6969
is_causal = True
7070
# implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
7171
use_fp32_acc = kwargs.get("use_fp32_acc", False)
72+
use_fp8_quantize = kwargs.get("use_fp8_quantize", True)
7273
query_dtype = query.dtype
7374

7475
if scale is None:
@@ -97,6 +98,30 @@ def scaled_dot_product_attention(
9798
key,
9899
scale,
99100
)
101+
# NOTE: amax is hard-coded to a fixed value for testing only
102+
amax = torch.tensor([0.6562])
103+
if use_fp8_quantize:
104+
key = impl.quantize.quantize(
105+
ctx,
106+
target,
107+
SourceIR.ATEN,
108+
name,
109+
key,
110+
amax,
111+
8,
112+
4,
113+
)
114+
115+
query = impl.quantize.quantize(
116+
ctx,
117+
target,
118+
SourceIR.ATEN,
119+
name,
120+
query,
121+
amax,
122+
8,
123+
4,
124+
)
100125

101126
if use_fp32_acc and query_dtype == trt.float16:
102127
query = cast_trt_tensor(
@@ -173,6 +198,29 @@ def scaled_dot_product_attention(
173198
softmax = impl.normalization.softmax(
174199
ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
175200
)
201+
if use_fp8_quantize:
202+
softmax = impl.quantize.quantize(
203+
ctx,
204+
target,
205+
SourceIR.ATEN,
206+
name,
207+
softmax,
208+
amax,
209+
8,
210+
4,
211+
)
212+
213+
value = impl.quantize.quantize(
214+
ctx,
215+
target,
216+
SourceIR.ATEN,
217+
name,
218+
value,
219+
amax,
220+
8,
221+
4,
222+
)
223+
176224
if use_fp32_acc:
177225
softmax = cast_trt_tensor(
178226
ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir
@@ -188,9 +236,21 @@ def scaled_dot_product_attention(
188236
softmax,
189237
value,
190238
)
239+
191240
if use_fp32_acc:
192241
out = cast_trt_tensor(
193242
ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir
194243
)
244+
if use_fp8_quantize:
245+
out = impl.quantize.quantize(
246+
ctx,
247+
target,
248+
SourceIR.ATEN,
249+
name,
250+
out,
251+
amax,
252+
8,
253+
4,
254+
)
195255

196256
return out

0 commit comments

Comments
 (0)