LevAttention: Time, Space, and Streaming Efficient Algorithm for Heavy Attentions

Abstract

A central problem related to transformers can be stated as follows: given two $n \times d$ matrices $Q$ and $K$, and a non-negative function $f$, define the matrix $A$ as follows: (1) apply the function $f$ to each entry of the $n \times n$ matrix $Q K^T$, and then (2) normalize the rows of the resulting matrix so that each row sum equals $1$. The matrix $A$ can be computed in $O(n^2 d)$ time, assuming $f$ can be applied to a number in constant time, but the quadratic dependence on $n$ is prohibitive in applications where it corresponds to long context lengths. For a large class of functions $f$, we show how to find all the ``large attention scores'', i.e., entries of $A$ that are at least a positive value $\varepsilon$, in time with linear dependence on $n$ (i.e., $n \cdot \textrm{poly}(d/\varepsilon)$) for a positive parameter $\varepsilon > 0$. Our class of functions includes all functions $f$ of the form $f(x) = |x|^p$, as explored recently in transformer models. Using recently developed tools from randomized numerical linear algebra, we prove that for any $K$, there is a ``universal set'' $U \subset [n]$ of size independent of $n$, such that for any $Q$ and any row $i$, the large attention scores $A_{i,j}$ in row $i$ of $A$ all have $j \in U$. We also find $U$ in $n \cdot \textrm{poly}(d/\varepsilon)$ time. Notably, (1) we make no assumptions on the data, (2) our workspace does not grow with $n$, and (3) our algorithms can be computed in streaming and parallel settings. We call the attention mechanism that uses only the subset of keys in the universal set LevAttention, since our algorithm for identifying $U$ is based on leverage scores. We empirically demonstrate the benefits of our scheme for vision transformers, showing how to train new models that use our universal set during training as well, and showing that our model consistently selects ``important keys'' during training.
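The sketch below is a minimal NumPy illustration of the two ingredients described above, not the paper's exact algorithm: (i) the generalized attention matrix $A$ obtained by applying $f(x) = |x|^p$ entrywise to $Q K^T$ and row-normalizing, and (ii) a candidate key subset chosen by leverage scores of $K$. The simple threshold `eps` used to select keys here is an assumption for illustration; the paper's universal set $U$ has size $\textrm{poly}(d/\varepsilon)$ with provable guarantees, which this placeholder does not reproduce.

```python
import numpy as np

def leverage_scores(K):
    """Leverage score of row i of K: k_i^T (K^T K)^+ k_i."""
    G_inv = np.linalg.pinv(K.T @ K)            # (d, d); pinv handles rank deficiency
    return np.einsum('ij,jk,ik->i', K, G_inv, K)

def candidate_universal_set(K, eps):
    """Keys whose leverage score is at least eps.
    (Placeholder selection rule; the paper derives a poly(d/eps)-sized
    universal set from leverage scores with formal guarantees.)"""
    tau = leverage_scores(K)
    return np.nonzero(tau >= eps)[0]

def lev_attention(Q, K, V, eps, p=2):
    """Attention restricted to the candidate set U:
    A_{i,j} proportional to |q_i . k_j|^p for j in U, rows normalized to sum to 1."""
    U = candidate_universal_set(K, eps)
    S = np.abs(Q @ K[U].T) ** p                 # (n, |U|) unnormalized scores
    S /= S.sum(axis=1, keepdims=True) + 1e-12   # row-normalize
    return S @ V[U], U
```

Because the candidate set depends only on $K$, it can be computed once and reused for every query row, which is what makes the overall cost linear in $n$ rather than quadratic.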
