Triton & 九齿 2024 冬季作业 #9
voltjia
announced in
Announcements
Replies: 1 comment 1 reply
-
感觉九齿的符号定义类似于Sympy,或者说早年的静态图神经网络框架的网络定义,就是符号运算后还是一个符号,编译后有了真实数据后会带入编译后的图进行运算(因为现在才继续看九齿的课,所以没能及时在在线直播的评论区吐槽,就转Github讨论区了,若不合时宜可直接删除此评论,就是Sympy的例子更加直观,可体验,作为示范可以加深理解) |
Beta Was this translation helpful? Give feedback.
1 reply
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
本作业的目标是让学员实现一个基于 Triton 的注意力计算内核,并确保它与 PyTorch 中的
scaled_dot_product_attention
函数的输出一致。具体来说,作业要求实现一个带有is_causal
参数的注意力计算内核,该参数控制是否使用因果注意力。你需要通过 Triton 实现这一功能,并验证其正确性。作业内容
1. 理解注意力机制
在注意力机制中,给定查询(Query)、键(Key)和值(Value),计算过程通常为:
其中,因果注意力(Causal Attention)通过掩蔽未来时间步(确保当前位置只与之前的位置进行交互)来实现。
PyTorch 中的
scaled_dot_product_attention
实现了该机制。你将基于 Triton 实现相同的计算。感兴趣的同学可以阅读 Attention is All You Need。
2. 实现步骤
给定
attention(query, key, value, is_causal=False, scale=None)
函数的签名,你需要实现核心的注意力计算。在这里:query
,key
,value
是注意力机制中的三个张量,分别表示查询、键和值;is_causal
用于控制是否应用因果掩蔽;scale
控制是否进行缩放处理,通常在计算点积时除以键的维度的平方根。你可以参考 FlashAttention 等算法进行实现。
在
test_attention_kernel
中,我们提供了一个与 PyTorch 结果对比的测试函数compare_results
。你需要确保你的 Triton 实现与 PyTorch 的scaled_dot_product_attention
输出一致。请使用以下命令运行测试:3. 提交要求
attention
函数签名;attention
函数体以调用 Triton 内核;测试代码
希望这个作业帮助大家更好地理解 Triton 与深度学习中的并行计算。
Beta Was this translation helpful? Give feedback.
All reactions