Fast Attention Requires Bounded Entries

Abstract

In modern machine learning, inner product attention computation is a fundamental task for training large language models such as Transformer, GPT-1, BERT, GPT-2, GPT-3, and ChatGPT. Formally, in this problem, one is given as input three matrices $Q, K, V \in [-B,B]^{n \times d}$, and the goal is to construct the matrix $\mathrm{Att}(Q,K,V) := \mathrm{diag}(A \mathbf{1}_n)^{-1} A V \in \mathbb{R}^{n \times d}$, where $A = \exp(QK^\top/d)$ is the 'attention matrix', and $\exp$ is applied entry-wise. Straightforward methods for this problem explicitly compute the $n \times n$ attention matrix $A$, and hence require time $\Omega(n^2)$ even when $d = n^{o(1)}$ is small. In this paper, we investigate whether faster algorithms are possible by implicitly making use of the matrix $A$. We present two results, showing that there is a sharp transition at $B = \Theta(\sqrt{\log n})$.

• If $d = O(\log n)$ and $B = o(\sqrt{\log n})$, there is an $n^{1+o(1)}$ time algorithm to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error.

• If $d = O(\log n)$ and $B = \Theta(\sqrt{\log n})$, then assuming the Strong Exponential Time Hypothesis from fine-grained complexity theory, it is impossible to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error in truly subquadratic time $n^{2-\Omega(1)}$.

This gives a theoretical explanation for the phenomenon observed in practice that attention computation is much more efficient when the input matrices have smaller entries.
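
For concreteness, the following is a minimal NumPy sketch of the straightforward quadratic-time computation the abstract refers to: it forms the $n \times n$ attention matrix $A$ explicitly and then applies the row normalization $\mathrm{diag}(A \mathbf{1}_n)^{-1}$. The function name `naive_attention` and the toy parameters below are illustrative choices, not from the paper.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Straightforward O(n^2 d) computation of Att(Q, K, V) as defined above.

    Q, K, V are n x d matrices. The n x n attention matrix A is materialized
    explicitly, which is exactly the quadratic-time bottleneck the paper studies.
    """
    n, d = Q.shape
    A = np.exp(Q @ K.T / d)             # A = exp(QK^T / d), exp applied entry-wise
    D_inv = 1.0 / A.sum(axis=1)         # diag(A 1_n)^{-1}, stored as a vector of row-sum inverses
    return (A * D_inv[:, None]) @ V     # diag(A 1_n)^{-1} A V, an n x d matrix

# Example usage on small random inputs with entries bounded by B (hypothetical toy sizes)
rng = np.random.default_rng(0)
n, d, B = 8, 4, 1.0
Q, K, V = (rng.uniform(-B, B, size=(n, d)) for _ in range(3))
print(naive_attention(Q, K, V).shape)   # (8, 4)
```

The point of the sketch is only to make the cost concrete: both the memory and the time of this direct approach scale with the $n^2$ entries of $A$, which is what the paper's upper and lower bounds are about.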
