Conversation

@66RING 66RING commented Nov 10, 2025

A timestep embedding kernel implementation.

TODO

  • triton kernel
  • cuda kernel
  • test, perf and report
    • tested on 4090
      • 3.2x speedup for the CUDA kernel
      • 2.5x speedup for the Triton kernel
    • other GPUs
  • e2e test
    • e2e performance
    • e2e generation
  • remove all debug marks: TODO:, NOTE:
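For context, the operation being optimized is the standard sinusoidal timestep embedding used by diffusion models. A minimal NumPy reference sketch (illustrative only; `timestep_embedding_ref` and its signature are not the PR's API):

```python
import numpy as np

def timestep_embedding_ref(t, dim, max_period=10000):
    # Half of the dimensions carry cosines, the other half sines.
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # odd dims get one zero-padded column
        emb = np.concatenate([emb, np.zeros((emb.shape[0], 1))], axis=-1)
    return emb
```

The kernels under test compute exactly this shape of output, so a reference like the above is what the correctness test compares against.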
python -m unittest discover -s test/visual_embedding -p "test_timestep*.py" -v
W1117 22:43:28.146000 405083 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W1117 22:43:28.146000 405083 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
test_correctness (test_timestep_embedding_kernel.TestTimestepEmbed.test_correctness) ... ok
test_perf (test_timestep_embedding_kernel.TestTimestepEmbed.test_perf) ... === Timestep Embedding Benchmark Results ===
╒══════════════╤═════════════╤═══════════════════╤════════════════════╤══════════════════╤════════════════════╤══════════════════╕
│   Batch Size │   Dimension │   Torch Time (ms) │   Triton Time (ms) │   CUDA Time (ms) │   Speedup (Triton) │   Speedup (CUDA) │
╞══════════════╪═════════════╪═══════════════════╪════════════════════╪══════════════════╪════════════════════╪══════════════════╡
│            1 │          32 │          0.271989 │           0.109464 │          0.03409 │            2.48473 │          7.9779  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │          64 │          0.179808 │           0.069002 │          0.04043 │            2.60585 │          4.44699 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │         128 │          0.091829 │           0.042419 │          0.03079 │            2.16479 │          2.98223 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │         512 │          0.100058 │           0.040594 │          0.03375 │            2.46486 │          2.96506 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │         613 │          0.144597 │           0.038798 │          0.03511 │            3.72688 │          4.11816 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │        1024 │          0.107922 │           0.040682 │          0.03267 │            2.65284 │          3.30367 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │        4096 │          0.090890 │           0.043629 │          0.03647 │            2.08325 │          2.49237 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │        4099 │          0.116331 │           0.043619 │          0.03366 │            2.66697 │          3.45598 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │        8192 │          0.095914 │           0.042835 │          0.03231 │            2.23913 │          2.96821 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            1 │        8197 │          0.128936 │           0.060627 │          0.03019 │            2.1267  │          4.27121 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │          32 │          0.104677 │           0.039646 │          0.02884 │            2.64026 │          3.63017 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │          64 │          0.093830 │           0.040718 │          0.03217 │            2.30437 │          2.91631 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │         128 │          0.084614 │           0.045166 │          0.03040 │            1.87339 │          2.78366 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │         512 │          0.095648 │           0.044272 │          0.03387 │            2.16046 │          2.82434 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │         613 │          0.112157 │           0.042038 │          0.03116 │            2.66796 │          3.59901 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │        1024 │          0.090045 │           0.039230 │          0.03496 │            2.29528 │          2.57542 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │        4096 │          0.122835 │           0.039062 │          0.03218 │            3.14459 │          3.8176  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │        4099 │          0.124685 │           0.039438 │          0.03206 │            3.16151 │          3.88882 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │        8192 │          0.083885 │           0.042750 │          0.02955 │            1.9622  │          2.8384  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            2 │        8197 │          0.124118 │           0.047198 │          0.03916 │            2.62972 │          3.16965 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │          32 │          0.130493 │           0.039266 │          0.03219 │            3.32334 │          4.05438 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │          64 │          0.086528 │           0.038509 │          0.03108 │            2.24697 │          2.78375 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │         128 │          0.081915 │           0.044549 │          0.03098 │            1.83877 │          2.64406 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │         512 │          0.096869 │           0.038600 │          0.03493 │            2.50955 │          2.77351 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │         613 │          0.112197 │           0.050648 │          0.03390 │            2.21523 │          3.30972 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │        1024 │          0.104306 │           0.038992 │          0.03744 │            2.67505 │          2.78582 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │        4096 │          0.095330 │           0.043398 │          0.03251 │            2.19662 │          2.93271 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │        4099 │          0.115037 │           0.044704 │          0.03152 │            2.5733  │          3.65002 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │        8192 │          0.086014 │           0.041253 │          0.03324 │            2.08506 │          2.58743 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│            8 │        8197 │          0.115706 │           0.039346 │          0.03137 │            2.94075 │          3.68865 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │          32 │          0.084331 │           0.041794 │          0.03120 │            2.0178  │          2.70251 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │          64 │          0.085856 │           0.047995 │          0.03346 │            1.78885 │          2.56562 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │         128 │          0.088240 │           0.040171 │          0.03398 │            2.1966  │          2.59701 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │         512 │          0.098192 │           0.039211 │          0.03201 │            2.50418 │          3.06743 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │         613 │          0.112726 │           0.038378 │          0.03167 │            2.9373  │          3.55954 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │        1024 │          0.085674 │           0.039834 │          0.03185 │            2.15079 │          2.68994 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │        4096 │          0.217578 │           0.041019 │          0.03266 │            5.30429 │          6.66174 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │        4099 │          0.118774 │           0.039573 │          0.03204 │            3.00142 │          3.70744 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │        8192 │          0.091096 │           0.042040 │          0.03385 │            2.16689 │          2.6912  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│           63 │        8197 │          0.124536 │           0.045085 │          0.03243 │            2.76226 │          3.8401  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │          32 │          0.085251 │           0.038451 │          0.03274 │            2.21713 │          2.60395 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │          64 │          0.090818 │           0.039843 │          0.03360 │            2.27938 │          2.70278 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │         128 │          0.089048 │           0.040128 │          0.03248 │            2.2191  │          2.7419  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │         512 │          0.095618 │           0.040096 │          0.03381 │            2.38472 │          2.82825 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │         613 │          0.113858 │           0.040682 │          0.03343 │            2.79875 │          3.40565 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │        1024 │          0.092723 │           0.047123 │          0.03372 │            1.96768 │          2.74954 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │        4096 │          0.097002 │           0.043013 │          0.03506 │            2.25518 │          2.76667 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │        4099 │          0.116541 │           0.043736 │          0.03437 │            2.66464 │          3.39097 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │        8192 │          0.097611 │           0.047749 │          0.03714 │            2.04426 │          2.62848 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          256 │        8197 │          0.127843 │           0.047734 │          0.03611 │            2.67822 │          3.54034 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │          32 │          0.082366 │           0.041538 │          0.03222 │            1.98294 │          2.55669 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │          64 │          0.138365 │           0.040026 │          0.03654 │            3.45691 │          3.78675 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │         128 │          0.083654 │           0.040470 │          0.03320 │            2.06705 │          2.51971 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │         512 │          0.090722 │           0.039478 │          0.03300 │            2.29801 │          2.74887 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │         613 │          0.129771 │           0.044115 │          0.03421 │            2.94164 │          3.79359 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │        1024 │          0.092245 │           0.043085 │          0.03516 │            2.14101 │          2.62321 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │        4096 │          0.093203 │           0.046376 │          0.05712 │            2.00973 │          1.6318  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │        4099 │          0.118640 │           0.053816 │          0.03756 │            2.20455 │          3.15868 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │        8192 │          0.115054 │           0.058421 │          0.05282 │            1.96941 │          2.1782  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│          613 │        8197 │          0.127667 │           0.056478 │          0.04460 │            2.26046 │          2.86249 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │          32 │          0.095477 │           0.072920 │          0.03639 │            1.30934 │          2.62368 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │          64 │          0.101080 │           0.039322 │          0.03251 │            2.5706  │          3.10901 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │         128 │          0.087342 │           0.039102 │          0.03254 │            2.23368 │          2.68448 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │         512 │          0.089330 │           0.040786 │          0.04412 │            2.19022 │          2.02448 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │         613 │          0.113472 │           0.039738 │          0.05112 │            2.85553 │          2.21986 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │        1024 │          0.312182 │           0.049955 │          0.03504 │            6.24925 │          8.90972 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │        4096 │          0.119824 │           0.057813 │          0.05157 │            2.07262 │          2.3234  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │        4099 │          0.122275 │           0.059843 │          0.04276 │            2.04326 │          2.85968 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │        8192 │          0.148262 │           0.075445 │          0.07545 │            1.96518 │          1.96497 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1024 │        8197 │          0.184182 │           0.089733 │          0.06918 │            2.05256 │          2.6624  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │          32 │          0.106229 │           0.046114 │          0.03398 │            2.30363 │          3.12629 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │          64 │          0.170469 │           0.064019 │          0.03328 │            2.66278 │          5.123   │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │         128 │          0.093403 │           0.041966 │          0.03353 │            2.22567 │          2.7853  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │         512 │          0.157189 │           0.045018 │          0.04044 │            3.49172 │          3.88696 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │         613 │          0.214784 │           0.130040 │          0.10653 │            1.65168 │          2.0161  │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │        1024 │          0.123058 │           0.050464 │          0.04100 │            2.43852 │          3.00117 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │        4096 │          0.118032 │           0.061974 │          0.05943 │            1.90453 │          1.98605 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │        4099 │          0.149854 │           0.066464 │          0.04925 │            2.25467 │          3.04295 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │        8192 │          0.236622 │           0.089696 │          0.09722 │            2.63805 │          2.43391 │
├──────────────┼─────────────┼───────────────────┼────────────────────┼──────────────────┼────────────────────┼──────────────────┤
│         1536 │        8197 │          0.350042 │           0.133008 │          0.10504 │            2.63173 │          3.33261 │
╘══════════════╧═════════════╧═══════════════════╧════════════════════╧══════════════════╧════════════════════╧══════════════════╛
ok

----------------------------------------------------------------------
Ran 2 tests in 13.685s

OK

Average Speedup(triton): 2.4760
Average Speedup(cuda): 3.1957

@66RING 66RING requested a review from mickqian as a code owner November 10, 2025 13:15
@gemini-code-assist
Contributor

Summary of Changes

Hello @66RING, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the sglang library by integrating a high-performance Triton kernel for timestep embedding. This optimization is crucial for diffusion models and aims to significantly improve the computational efficiency of this operation, ensuring faster processing while maintaining accuracy.

Highlights

  • Triton Kernel Implementation: Introduced a new Triton-based kernel, _timestep_embedding_triton_kernel, and its Python wrapper, timestep_embedding_triton, for efficient sinusoidal timestep embedding calculations.
  • Performance and Correctness Testing: Added a comprehensive test suite (test_timestep_embedding_kernel.py) to verify the correctness of the Triton kernel against the existing PyTorch implementation and to benchmark its performance across various batch sizes and dimensions.
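Roughly, such a Triton kernel tiles the output over (batch row, dim block) programs, with each program filling the cosine half and sine half of its block. A plain-Python sketch of that layout (illustrative, assuming the standard sinusoidal formula; this is not the PR's `_timestep_embedding_triton_kernel`):

```python
import math

def timestep_embedding_blocks(timesteps, dim, max_period=10000, block_size=128):
    half = dim // 2
    out = [[0.0] * dim for _ in timesteps]
    for b, t in enumerate(timesteps):            # grid axis 0: one program per batch row
        for start in range(0, dim, block_size):  # grid axis 1: one program per dim block
            for d in range(start, min(start + block_size, dim)):
                if d < 2 * half:                 # for odd dim the last column stays zero
                    freq = math.exp(-math.log(max_period) * (d % half) / half)
                    out[b][d] = math.cos(t * freq) if d < half else math.sin(t * freq)
    return out
```

In the real kernel the inner loop over `d` becomes a vectorized block of `tl.arange` offsets; the sketch only shows how the grid maps onto the output.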

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a Triton kernel for timestep embedding, which is a great optimization. The implementation correctly mirrors the existing PyTorch logic. My review focuses on improving robustness, simplifying the kernel code, and enhancing test coverage.

I've identified a significant edge case where dim < 2 would cause a division-by-zero error in both the new Triton kernel and the original PyTorch function. I've provided a suggestion to handle this in the host function to make it more robust. Additionally, I've suggested a couple of simplifications within the Triton kernel for better readability and performance. Finally, I've noted that the dtype test case is unimplemented and offered a sample implementation to ensure the kernel works correctly with different precisions like float16.

Comment on lines 1032 to 1035
B = t.shape[0]
assert t.is_cuda, "t should be a CUDA tensor"

output = torch.empty((B, dim), dtype=dtype, device='cuda')

high

The kernel _timestep_embedding_triton_kernel will fail with a division-by-zero error if dim < 2, because half will be 0. While the reference PyTorch implementation appears to have the same issue, it's good practice to make this function robust against such edge cases. I suggest adding a check at the beginning of the function to handle small dim values before calling the kernel.

    B = t.shape[0]
    assert t.is_cuda, "t should be a CUDA tensor"

    if dim < 2:
        if dim == 0:
            return torch.empty((B, 0), dtype=dtype, device='cuda')
        else:  # dim == 1
            return torch.zeros((B, 1), dtype=dtype, device='cuda')

    output = torch.empty((B, dim), dtype=dtype, device='cuda')

Comment on lines 988 to 992
freq_indices = tl.where(
    is_first_half,
    d_offsets,
    tl.where(is_second_half, d_offsets - half, 0),
)

medium

The calculation of freq_indices can be simplified. The nested tl.where is equivalent to a modulo operation, which is more concise and potentially more performant. Once the division-by-zero issue for dim < 2 is handled in the host function, half will be guaranteed to be >= 1 in the kernel, making the modulo operation safe.

    freq_indices = d_offsets % half
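The equivalence is easy to sanity-check on the host side: for offsets in `[0, dim)`, including the zero-padded tail index when `dim` is odd, the nested select and the modulo agree, provided `half >= 1` (a hedged NumPy check, not code from the PR):

```python
import numpy as np

half = 8
dim = 2 * half + 1  # odd dim: the last offset maps to the padded column
d_offsets = np.arange(dim)

# Nested-select form, mirroring the two tl.where calls in the kernel.
nested = np.where(d_offsets < half, d_offsets,
                  np.where(d_offsets < 2 * half, d_offsets - half, 0))
# Proposed simplification.
modulo = d_offsets % half

assert (nested == modulo).all()
```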


# Calculate freqs and angles
dtype = output_ptr.dtype.element_ty
log_max_period = tl.log(tl.full((BLOCK_SIZE_DIM,), max_period, dtype))

medium

The creation of log_max_period can be made more efficient. tl.full((BLOCK_SIZE_DIM,), ...) creates a vector, but since max_period is a scalar, a scalar operation is sufficient and will be broadcasted automatically by Triton. Using float(max_period) is safe as the subsequent calculations promote the result to float32 anyway.

    log_max_period = tl.log(float(max_period))



def test_dtype(self):
    pass

medium

This test case for dtype is currently empty. It's important to verify that the Triton kernel works correctly with different data types, such as torch.float16, which is common for performance. Please consider implementing this test to ensure the kernel is robust.

Suggested change

    def test_dtype(self):
        device = "cuda"
        for dtype in [torch.float32, torch.float16]:
            # Use a representative batch size and dimension
            B, dim = 16, 256
            t = torch.randn((B,), device=device)
            torch_output = timestep_embedding(t, dim, dtype=dtype)
            triton_output = timestep_embedding_triton(t, dim, dtype=dtype)
            # Use a larger tolerance for float16
            atol = 1e-2 if dtype == torch.float16 else 1e-6
            assert torch.allclose(torch_output, triton_output, atol=atol), f"Mismatch for dtype={dtype}"

@mickqian
Collaborator

mickqian commented Nov 10, 2025

We might need to perform a full examination of the embedders in:

  1. visual_embedding.py
  2. WanTimeTextImageEmbedding
  3. QwenTimestepProjEmbeddings (which I think some of it could already be replaced with components in visual_embedding.py)
  4. time_text_embed in flux.py

and put the common reusable components in visual_embedding.py or another appropriate place.

using namespace flashinfer;

// // TODO: debug only for now
// #include "sgl_kernel_ops.h"
Author

No idea where to place the CUDA code for sgl-kernel. Is sgl-kernel now a separate directory from the LLM code?

}
}

template <typename T> __device__ __nv_bfloat16 convert_to_bfloat16(T x) {
Author

Reusable code, hard-coded here for debugging for now. Should I use flashinfer-style utilities?


}()

// TODO:
// assert operations is float??
Author

Looks like the Python code always returns float32, so float32 is hard-coded for now.

// Heuristics tuning
// WARN: this will generate a lot of template functions:
// (DIM_SWITCH * DISPATCH_FLOAT_TYPES).
DIM_SWITCH(dim, kDim, /* bad case */ 1, [&] {
Author

The static switch may cause long compile times. Would always using vec_size=1 be a good choice? Any ideas?

)


class TestTimestepEmbed(unittest.TestCase):
Author

The unittest style is a bit different from the LLM part; I simply copied the layout from the ./test folder.

@mickqian mickqian requested a review from BBuf November 18, 2025 01:23
@mickqian
Collaborator

Added @BBuf for review, much thanks.
