
FlashKAT: Understanding and Addressing Performance Bottlenecks in the Kolmogorov-Arnold Transformer

Main: 6 pages; Bibliography: 2 pages; Appendix: 1 page; 4 figures; 8 tables
Abstract

The Kolmogorov-Arnold Network (KAN) has been gaining popularity as an alternative to the multilayer perceptron (MLP) due to its greater expressiveness and interpretability. Even so, KAN suffers from training instability and is orders of magnitude slower because of its higher computational cost, which limits its applicability to large-scale tasks. Recently, the Kolmogorov-Arnold Transformer (KAT) was proposed; by leveraging Group-Rational KAN (GR-KAN), it achieves FLOPs comparable to those of traditional Transformer models with MLPs. Unfortunately, despite the comparable FLOPs, our testing shows that KAT remains 123x slower to train than its MLP-based counterpart, indicating that there are performance bottlenecks beyond FLOPs. In this paper, we conduct a series of experiments to understand the root cause of the slowdown in KAT. We find that the slowdown can be traced to memory stalls and, more specifically, to inefficient gradient accumulation in the backward pass of GR-KAN. To address this memory bottleneck, we propose FlashKAT, which uses a restructured kernel to minimize accesses to slow memory and the use of atomic additions. Evaluations show that FlashKAT achieves up to an 86.5x training speedup over the state-of-the-art KAT implementation while reducing rounding errors in gradient computation.
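To make the gradient-accumulation bottleneck concrete, here is a minimal, CPU-only Python sketch; it is not FlashKAT's kernel. It assumes a "safe Padé" rational activation F(x) = P(x) / (1 + |Q(x)|) with numerator order 5 and denominator order 4, and a single group whose numerator coefficients a_i are shared by every element, so each element contributes to the same few gradient entries. The function names, the block size of 256, and the write counter (standing in for atomic adds to slow global memory) are all hypothetical illustrations. `accumulate_naive` issues one shared-buffer add per element per coefficient, while `accumulate_blocked` first reduces each block in a local buffer (standing in for fast on-chip memory) and then issues one add per coefficient per block.

```python
import random

NUM_A = 6  # numerator coefficients a0..a5 (assumed order m = 5)


def q_poly(x, b):
    # Q(x) = b1*x + b2*x^2 + ... + bn*x^n (no constant term)
    return sum(bj * x ** (j + 1) for j, bj in enumerate(b))


def rational(x, a, b):
    # Assumed safe Pade form: F(x) = P(x) / (1 + |Q(x)|)
    p = sum(ai * x ** i for i, ai in enumerate(a))
    return p / (1.0 + abs(q_poly(x, b)))


def grad_a_terms(x, b, g):
    # dF/da_i = x^i / (1 + |Q(x)|), scaled by the upstream gradient g
    denom = 1.0 + abs(q_poly(x, b))
    return [g * x ** i / denom for i in range(NUM_A)]


def accumulate_naive(xs, gs, b):
    # One shared-buffer add (an atomic add on a GPU) per element per coefficient
    grad = [0.0] * NUM_A
    writes = 0
    for x, g in zip(xs, gs):
        for i, t in enumerate(grad_a_terms(x, b, g)):
            grad[i] += t
            writes += 1
    return grad, writes


def accumulate_blocked(xs, gs, b, block=256):
    # Reduce each block in a local buffer, then one shared add per coefficient
    grad = [0.0] * NUM_A
    writes = 0
    for start in range(0, len(xs), block):
        local = [0.0] * NUM_A
        for x, g in zip(xs[start:start + block], gs[start:start + block]):
            for i, t in enumerate(grad_a_terms(x, b, g)):
                local[i] += t
        for i in range(NUM_A):
            grad[i] += local[i]
            writes += 1
    return grad, writes


if __name__ == "__main__":
    random.seed(0)
    b = [0.5, -0.2, 0.1, 0.05]  # hypothetical denominator coefficients b1..b4
    xs = [random.uniform(-2.0, 2.0) for _ in range(10_000)]
    gs = [random.uniform(-1.0, 1.0) for _ in range(10_000)]
    g1, w1 = accumulate_naive(xs, gs, b)
    g2, w2 = accumulate_blocked(xs, gs, b)
    print("shared writes:", w1, "vs", w2)
    print("max |difference|:", max(abs(u - v) for u, v in zip(g1, g2)))
```

Both strategies compute the same gradient up to floating-point rounding, but the blocked version issues 240 shared writes instead of 60,000 for 10,000 elements. That reduction in contended writes to slow memory is the kind of saving the abstract attributes to the restructured kernel. The sketch deliberately does not model GPU timing or the precision behavior reported for FlashKAT.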
