Provable In-context Learning for Mixture of Linear Regressions using Transformers

Main: 11 pages
Figures: 4
Bibliography: 5 pages
Appendix: 36 pages
Abstract

We theoretically investigate the in-context learning capabilities of transformers in the context of learning mixtures of linear regression models. For the case of a mixture with two components, we demonstrate the existence of transformers that can achieve an accuracy, relative to the oracle predictor, of order $\tilde{\mathcal{O}}((d/n)^{1/4})$ in the low signal-to-noise ratio (SNR) regime and $\tilde{\mathcal{O}}(\sqrt{d/n})$ in the high SNR regime, where $n$ is the length of the prompt and $d$ is the dimension of the problem. Additionally, we derive in-context excess risk bounds of order $\mathcal{O}(L/\sqrt{B})$, where $B$ denotes the number of (training) prompts and $L$ represents the number of attention layers. The order of $L$ depends on whether the SNR is low or high. In the high SNR regime, we extend the results to $K$-component mixture models for finite $K$. Extensive simulations also highlight the advantages of transformers for this task, outperforming other baselines such as the Expectation-Maximization algorithm.
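The following is a minimal sketch, under assumptions of our own, of the kind of setup the abstract describes: a prompt of $n$ examples generated from a mixture of two linear regressions, the oracle predictor that knows both component vectors and weights them by their posterior given the prompt, and a plain least-squares fit on the prompt as a simple reference point. The exact sampling protocol, SNR settings, and baselines used in the paper may differ; all names and parameter values below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, sigma = 8, 256, 0.5                      # dimension d, prompt length n, noise level

# Two regression vectors defining the mixture components (assumed known to the oracle).
w = rng.normal(size=(2, d))
z = rng.integers(0, 2)                         # latent component generating this prompt
X = rng.normal(size=(n, d))
y = X @ w[z] + sigma * rng.normal(size=n)      # in-context examples (x_i, y_i)
x_q = rng.normal(size=d)                       # query input

# Oracle predictor: posterior over the latent component given the prompt,
# then a posterior-weighted prediction using the true component vectors.
log_lik = np.array([-0.5 * np.sum((y - X @ wk) ** 2) / sigma**2 for wk in w])
post = np.exp(log_lik - log_lik.max())
post /= post.sum()
oracle_pred = post @ (w @ x_q)

# Simple in-context reference: (lightly regularized) least squares on the prompt alone.
w_ls = np.linalg.solve(X.T @ X + 1e-6 * np.eye(d), X.T @ y)
ls_pred = w_ls @ x_q

print(f"true signal {w[z] @ x_q:.3f}, oracle {oracle_pred:.3f}, least squares {ls_pred:.3f}")
```

In this sketch the oracle's advantage comes entirely from knowing the component vectors; the paper's results bound how closely a transformer, trained on $B$ prompts, can match that oracle as a function of $d$, $n$, and the SNR.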
