[Diffusion] timestep embedding kernel implementation #12995
base: main
Conversation
Summary of Changes

Hello @66RING, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Code Review
This pull request introduces a Triton kernel for timestep embedding, which is a great optimization. The implementation correctly mirrors the existing PyTorch logic. My review focuses on improving robustness, simplifying the kernel code, and enhancing test coverage.
I've identified a significant edge case where dim < 2 would cause a division-by-zero error in both the new Triton kernel and the original PyTorch function. I've provided a suggestion to handle this in the host function to make it more robust. Additionally, I've suggested a couple of simplifications within the Triton kernel for better readability and performance. Finally, I've noted that the dtype test case is unimplemented and offered a sample implementation to ensure the kernel works correctly with different precisions like float16.
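For context, the kernel reproduces the standard sinusoidal timestep embedding used in diffusion models. A minimal PyTorch sketch of that reference logic, assuming the common cos/sin layout (names here are illustrative, not necessarily the exact ones in this PR):

import math
import torch

def timestep_embedding_ref(t, dim, max_period=10000, dtype=torch.float32):
    # First half of the embedding holds cosines, second half sines,
    # each evaluated at geometrically spaced frequencies.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(half, device=t.device, dtype=torch.float32) / half
    )
    args = t[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2 == 1:
        # Odd dims get a zero-padded last column.
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb.to(dtype)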
B = t.shape[0]
assert t.is_cuda, "t should be a CUDA tensor"

output = torch.empty((B, dim), dtype=dtype, device='cuda')
The kernel _timestep_embedding_triton_kernel will fail with a division-by-zero error if dim < 2, because half will be 0. While the reference PyTorch implementation appears to have the same issue, it's good practice to make this function robust against such edge cases. I suggest adding a check at the beginning of the function to handle small dim values before calling the kernel.
B = t.shape[0]
assert t.is_cuda, "t should be a CUDA tensor"
if dim < 2:
    if dim == 0:
        return torch.empty((B, 0), dtype=dtype, device='cuda')
    else:  # dim == 1
        return torch.zeros((B, 1), dtype=dtype, device='cuda')
output = torch.empty((B, dim), dtype=dtype, device='cuda')

freq_indices = tl.where(
    is_first_half,
    d_offsets,
    tl.where(is_second_half, d_offsets - half, 0)
)
The calculation of freq_indices can be simplified. The nested tl.where is equivalent to a modulo operation, which is more concise and potentially more performant. Once the division-by-zero issue for dim < 2 is handled in the host function, half will be guaranteed to be >= 1 in the kernel, making the modulo operation safe.
freq_indices = d_offsets % half
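As a quick sanity check (plain PyTorch rather than Triton, purely illustrative), the nested select and the modulo agree for all in-range offsets, i.e. 0 <= d < 2*half:

import torch

half = 4
d = torch.arange(2 * half)
nested = torch.where(
    d < half, d,
    torch.where((d >= half) & (d < 2 * half), d - half, torch.zeros_like(d))
)
assert torch.equal(nested, d % half)  # identical for in-range offsets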
# Calculate freqs and angles
dtype = output_ptr.dtype.element_ty
log_max_period = tl.log(tl.full((BLOCK_SIZE_DIM,), max_period, dtype))
The creation of log_max_period can be made more efficient. tl.full((BLOCK_SIZE_DIM,), ...) creates a vector, but since max_period is a scalar, a scalar operation is sufficient and will be broadcasted automatically by Triton. Using float(max_period) is safe as the subsequent calculations promote the result to float32 anyway.
log_max_period = tl.log(float(max_period))
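To illustrate the broadcasting behavior being relied on here, a small self-contained Triton sketch (hypothetical kernel, not from this PR) where a scalar argument is automatically broadcast across a block-sized vector:

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    # 'scale' is a scalar; Triton broadcasts it over the BLOCK-sized vector.
    tl.store(out_ptr + offs, x * scale, mask=mask)

x = torch.randn(1000, device="cuda")
out = torch.empty_like(x)
_scale_kernel[(triton.cdiv(1000, 256),)](x, out, 2.0, 1000, BLOCK=256)
assert torch.allclose(out, x * 2.0)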
def test_dtype(self):
    pass
This test case for dtype is currently empty. It's important to verify that the Triton kernel works correctly with different data types, such as torch.float16, which is common for performance. Please consider implementing this test to ensure the kernel is robust.
    device = "cuda"
    for dtype in [torch.float32, torch.float16]:
        # Use a representative batch size and dimension
        B, dim = 16, 256
        t = torch.randn((B,), device=device)
        torch_output = timestep_embedding(t, dim, dtype=dtype)
        triton_output = timestep_embedding_triton(t, dim, dtype=dtype)
        # Use a larger tolerance for float16
        atol = 1e-2 if dtype == torch.float16 else 1e-6
        assert torch.allclose(torch_output, triton_output, atol=atol), f"Mismatch for dtype={dtype}"
We might need to perform a full examination of the embedders in:
And put common reusable components in
using namespace flashinfer;

// // TODO: debug only for now
// #include "sgl_kernel_ops.h"
No idea where to place the CUDA code for sgld. Is SGLD now in a separate directory from the LLM code?
}
}

template <typename T> __device__ __nv_bfloat16 convert_to_bfloat16(T x) {
Reusable code; hard-coded here for debugging for now. Should I use flashinfer-style utilities?
cc @FlamingoPg
}()

// TODO:
// assert operations is float??
Looks like the Python code always returns float32, so float32 is hard-coded for now.
// Heuristics tuning
// WARN: which will generate a lot template function:
// (DIM_SWITCH * DISPATCH_FLOAT_TYPES).
DIM_SWITCH(dim, kDim, /* bad case */ 1, [&] {
A static switch may cause long compile times. Would always using vec_size=1 be a good choice? Any ideas?
)

class TestTimestepEmbed(unittest.TestCase):
The unittest style is a bit different from the LLM part; I simply copied the structure from the ./test folder.
added @BBuf for review. much thanks
a timestep embedding kernel implementation
TODO
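A rough usage sketch of the new path (the wrapper name is taken from the tests above; the exact import location is an assumption):

import torch
# from ... import timestep_embedding_triton  # import path is an assumption

t = torch.arange(8, device="cuda", dtype=torch.float32)          # timesteps, shape (B,)
emb = timestep_embedding_triton(t, 256, dtype=torch.float16)     # (B, dim) embedding
print(emb.shape)  # torch.Size([8, 256])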