
FLuRKA: Fast and accurate unified Low-Rank & Kernel Attention

Abstract

Many efficient approximate self-attention techniques have become prevalent since the inception of the transformer architecture. Two popular classes of these techniques are low-rank and kernel methods. Each of these methods has its strengths. We observe that these strengths complement each other synergistically and exploit this synergy to fuse low-rank and kernel methods, producing a new class of transformers: FLuRKA (Fast Low-Rank & Kernel Attention). FLuRKA are highly training-efficient, with faster model speeds and similar model quality compared to the constituent low-rank and kernel methods. We theoretically and empirically evaluate the speed and quality of FLuRKA. Our model speed analysis posits a variety of parameter configurations where FLuRKA exhibit speedups over low-rank and kernel approximations, and our model quality analysis bounds the error of FLuRKA with respect to full attention. Empirically, we instantiate three FLuRKA variants, which achieve speedups of up to 3.3x and 1.7x over low-rank and kernel methods respectively. This translates to speedups of up to 20x over models with flash-attention. Across a diverse set of tasks spanning language modeling, language understanding, long-sequence modeling, machine translation, and image classification, FLuRKA achieve accuracy comparable to the underlying low-rank and kernel approximations, occasionally surpassing both.
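To make the high-level idea of fusing the two approximation families concrete, the following is a minimal sketch of one plausible combination: a Linformer-style low-rank projection that compresses keys and values along the sequence dimension, followed by a kernelized (softmax-free) attention with a simple elu+1 feature map. The class name, hyperparameters, and the specific order of the two steps are illustrative assumptions, not the paper's exact FLuRKA formulation.

```python
import torch
import torch.nn as nn


class FusedLowRankKernelAttention(nn.Module):
    """Illustrative sketch (assumed, not the paper's exact method):
    1) compress K and V along the sequence axis with a learned low-rank
       projection (Linformer-style), then
    2) linearize attention over the compressed sequence with a positive
       kernel feature map (elu + 1, as in linear transformers)."""

    def __init__(self, d_model: int, seq_len: int, proj_len: int = 64):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        # Low-rank projections over the sequence dimension: seq_len -> proj_len.
        self.e_proj = nn.Linear(seq_len, proj_len, bias=False)
        self.f_proj = nn.Linear(seq_len, proj_len, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)

    @staticmethod
    def feature_map(x: torch.Tensor) -> torch.Tensor:
        # Positive feature map so the softmax-free attention weights stay non-negative.
        return torch.nn.functional.elu(x) + 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        q = self.feature_map(self.q_proj(x))                # (B, n, d)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # Low-rank step: compress K and V along the sequence axis.
        k = self.e_proj(k.transpose(1, 2)).transpose(1, 2)  # (B, k, d)
        v = self.f_proj(v.transpose(1, 2)).transpose(1, 2)  # (B, k, d)
        k = self.feature_map(k)
        # Kernel step: O = phi(Q) [phi(K)^T V] / (phi(Q) phi(K)^T 1)
        kv = torch.einsum("bkd,bke->bde", k, v)             # (B, d, d)
        z = 1.0 / (torch.einsum("bnd,bd->bn", q, k.sum(dim=1)) + 1e-6)
        out = torch.einsum("bnd,bde,bn->bne", q, kv, z)     # (B, n, d)
        return self.out_proj(out)


# Example usage (shapes only):
x = torch.randn(2, 512, 256)
attn = FusedLowRankKernelAttention(d_model=256, seq_len=512, proj_len=64)
y = attn(x)  # (2, 512, 256)
```

In this arrangement the kernelized contraction phi(K)^T V is formed over the compressed length k rather than the full sequence length n, which is one way the low-rank and kernel steps can compound; the paper's actual fusion and its speed/quality analysis may differ in detail.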
