
Exact Causal Attention with 10% Fewer Operations

Dmitry Rybin
Yushun Zhang
Ding Tian
Zhihang Lin
Ruoyu Sun
Zhi-Quan Luo
Main: 9 pages · 2 figures · Bibliography: 4 pages · 4 tables
Abstract

We present Fast Causal Attention (FCA), an algorithm that computes exact Causal Attention using 10% fewer operations. FCA accelerates a special class of matrix multiplications in which either one operand or the output matrix is upper- or lower-triangular. This class includes all operations in the forward and backward passes of Causal Attention, such as the masked product $\mathrm{Mask}(QK^{T})$. For these matrix multiplications on GPU, FCA achieves noticeable speedups over the default PyTorch implementations and Triton-compiled kernels. FCA is built upon algebraic identities discovered via machine learning and combinatorial search.
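To make the "triangular" class of products concrete, here is a minimal pure-Python sketch of the two cases the abstract names: a product whose output is lower-triangular, $\mathrm{Mask}(QK^{T})$, and a product with a lower-triangular operand (a causal score matrix $P$ times a dense $V$). This is NOT FCA. It only shows the obvious baseline saving of skipping the entries that are structurally zero. FCA's own algebraic identities are not described here. All function names are hypothetical. For simplicity the sketch assumes that $\mathrm{Mask}$ zeroes entries above the diagonal; in real attention those entries are typically set to $-\infty$ before the softmax.

```python
from typing import List

Matrix = List[List[float]]  # row-major list of rows (hypothetical helper type)


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense A @ B, computing every output entry."""
    return [[sum(a[i][t] * b[t][c] for t in range(len(b))) for c in range(len(b[0]))]
            for i in range(len(a))]


def matmul_transpose(q: Matrix, k: Matrix) -> Matrix:
    """Dense Q @ K^T: all n*n row-by-row dot products are computed."""
    return [[sum(x * y for x, y in zip(qi, kj)) for kj in k] for qi in q]


def mask_lower(s: Matrix) -> Matrix:
    """Causal mask (assumed here to zero out entries with j > i)."""
    return [[s[i][j] if j <= i else 0.0 for j in range(len(s[i]))] for i in range(len(s))]


def masked_product_direct(q: Matrix, k: Matrix) -> Matrix:
    """Triangular OUTPUT: compute Mask(Q K^T) without the dot products the mask discards."""
    out = [[0.0] * len(k) for _ in range(len(q))]
    for i in range(len(q)):
        for j in range(i + 1):  # only j <= i survives the causal mask
            out[i][j] = sum(x * y for x, y in zip(q[i], k[j]))
    return out


def tril_times_dense(p: Matrix, v: Matrix) -> Matrix:
    """Triangular OPERAND: P @ V for lower-triangular P, skipping the zero entries of P."""
    return [[sum(p[i][j] * v[j][c] for j in range(i + 1)) for c in range(len(v[0]))]
            for i in range(len(p))]


def scalar_mults(n: int, d: int, causal_only: bool) -> int:
    """Scalar multiplications for an n x n output with inner dimension d."""
    pairs = n * (n + 1) // 2 if causal_only else n * n
    return pairs * d


if __name__ == "__main__":
    q = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    print(masked_product_direct(q, k))  # equals mask_lower(matmul_transpose(q, k))
    print(scalar_mults(3, 2, causal_only=False), scalar_mults(3, 2, causal_only=True))
```

Skipping the masked half roughly halves the multiplications, since $n(n+1)/2$ entries survive out of $n^2$. Per the abstract, FCA's further 10% reduction comes from algebraic identities rather than from this kind of structural skipping.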
