
FlashAttention-2 in CuTe, from scratch -- a line-by-line walkthrough
Hey everyone, I spent a few months learning CuTe by re-implementing FA-2 from scratch on Ampere, then wrote up a thorough walkthrough of every important line in Tri Dao's source code.
CuTe is notorious for its steep learning curve, and CuTe code can be very hard to read. Most of what's online about CuTe is either NVIDIA's reference docs (not really a beginner guide), the production source code, or partial deep-dives that cover one concept at a time. I tried to fill that gap by walking through the whole kernel end-to-end, at a depth where you can fully understand why each decision was made.
The blog isn't a true beginner's guide, in the sense that it's not aimed at people who have never touched a kernel, but I tried to make it as accessible as possible to anyone who has even a vague notion of the most basic CUDA concepts.
We cover: swizzling and bank conflicts, tiled MMAs and fragment layouts, the LDSM (`ldmatrix`) copy atoms, V-transpose, online softmax via warp reductions, async copy pipelining, and the output store. I made my own diagrams and also included a few code improvements.
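For anyone who hasn't seen the online softmax trick before, here's a minimal CPU-side sketch in C++ of the recurrence the kernel implements for each query row: keep a running max and running sum, rescale the accumulated output by exp(old_max - new_max) whenever a new block of keys raises the max, and divide by the sum only once at the end (deferring that division is one of the FA-2 changes over FA-1). This is not Tri Dao's code; the names `online_softmax_attention`, `naive_softmax_attention`, and `kBlockN` are hypothetical, and it assumes finite scores (no masking). In the real kernel the row max and row sum are reduced across the threads that share a row using warp shuffles (e.g. `__shfl_xor_sync`) instead of a serial loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Online softmax attention for ONE query row (hypothetical helper, CPU sketch).
// scores[j] = (already scaled) q . k_j; values[j] = j-th row of V (d floats).
// Keys are consumed in blocks of kBlockN, like K/V tiles in the kernel.
// Assumes all scores are finite (no -inf masking).
std::vector<float> online_softmax_attention(const std::vector<float>& scores,
                                            const std::vector<std::vector<float>>& values,
                                            std::size_t kBlockN) {
    assert(kBlockN > 0 && "block size must be positive");
    const std::size_t n = scores.size();
    const std::size_t d = values.empty() ? 0 : values[0].size();
    float row_max = -std::numeric_limits<float>::infinity();
    float row_sum = 0.0f;
    std::vector<float> acc_o(d, 0.0f);  // un-normalized output accumulator

    for (std::size_t start = 0; start < n; start += kBlockN) {
        const std::size_t end = std::min(start + kBlockN, n);
        // 1) New running max over the old max and this block's scores.
        float new_max = row_max;
        for (std::size_t j = start; j < end; ++j) new_max = std::max(new_max, scores[j]);
        // 2) Rescale previous sum and accumulator; on the first block
        //    exp(-inf - new_max) == 0, which simply zeroes them.
        const float scale = std::exp(row_max - new_max);
        row_sum *= scale;
        for (float& o : acc_o) o *= scale;
        // 3) Add this block's exp(s - max) to the sum and P * V to the accumulator.
        for (std::size_t j = start; j < end; ++j) {
            const float p = std::exp(scores[j] - new_max);
            row_sum += p;
            for (std::size_t c = 0; c < d; ++c) acc_o[c] += p * values[j][c];
        }
        row_max = new_max;
    }
    // 4) Normalize once at the end instead of every iteration.
    for (float& o : acc_o) o /= row_sum;
    return acc_o;
}

// Reference: ordinary (two-pass) softmax followed by P * V for the same row.
std::vector<float> naive_softmax_attention(const std::vector<float>& scores,
                                           const std::vector<std::vector<float>>& values) {
    const float m = *std::max_element(scores.begin(), scores.end());
    float sum = 0.0f;
    std::vector<float> out(values[0].size(), 0.0f);
    for (std::size_t j = 0; j < scores.size(); ++j) {
        const float p = std::exp(scores[j] - m);
        sum += p;
        for (std::size_t c = 0; c < out.size(); ++c) out[c] += p * values[j][c];
    }
    for (float& o : out) o /= sum;
    return out;
}
```

Whatever block size you pick, the result should match the two-pass reference up to float rounding; that's the property that lets the kernel stream K/V tiles through shared memory without ever materializing the full score row.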
The kernel matches production FA-2's throughput on A100 (it's essentially the same algorithm, just stripped down to its essential core).
Hoping this is useful for anyone trying to ramp up on CuTe or read Tri Dao's source. Happy to answer questions in the comments!