
State-space models can learn in-context by gradient descent

Abstract

Deep state-space models (Deep SSMs) have become popular as effective approaches for modeling sequence data. They have also been shown to be capable of in-context learning, much like transformers. However, a complete picture of how SSMs might perform in-context learning has been missing. In this study, we provide a direct and explicit construction showing that state-space models can perform gradient-based learning and use it for in-context learning in much the same way as transformers. Specifically, we prove that a single structured state-space model layer, augmented with multiplicative input and output gating, can reproduce the outputs of an implicit linear model with least-squares loss after one step of gradient descent. We then show a straightforward extension to multi-step linear and non-linear regression tasks. We validate our construction by training randomly initialized augmented SSMs on linear and non-linear regression tasks. The parameters obtained through optimization match those predicted analytically by the theoretical construction. Overall, we elucidate the role of input and output gating in recurrent architectures as the key inductive bias enabling the expressive power typical of foundation models. We also provide novel insights into the relationship between state-space models and linear self-attention, and their ability to learn in-context.
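The core equivalence claimed in the abstract can be sketched numerically. The snippet below is a minimal illustration, not the authors' construction or code: all names, dimensions, and the learning rate are hypothetical. It compares (1) the prediction of an implicit linear model after one step of gradient descent on a least-squares loss over the context pairs, starting from zero weights, with (2) the same prediction computed as a gated cumulative-sum recurrence, which is the kind of computation a diagonal SSM layer with multiplicative input and output gating can express.

```python
import numpy as np

# Hypothetical illustration (not the paper's code): one-step GD on a
# least-squares loss vs. a gated recurrent accumulation of the same quantity.

rng = np.random.default_rng(0)
d_in, d_out, n_ctx = 4, 2, 16
lr = 0.1  # learning rate for the single gradient-descent step (assumed value)

# In-context regression task: context pairs (x_i, y_i) and a query x_q.
W_true = rng.normal(size=(d_out, d_in))
X = rng.normal(size=(n_ctx, d_in))
Y = X @ W_true.T
x_q = rng.normal(size=d_in)

# (1) One GD step from W0 = 0 on L(W) = 1/2 * sum_i ||W x_i - y_i||^2 gives
#     W1 = lr * sum_i y_i x_i^T, so the query prediction is W1 @ x_q.
W1 = lr * Y.T @ X
pred_gd = W1 @ x_q

# (2) Recurrent view: input gating forms the outer product y_i x_i^T token by
#     token, a linear recurrence with identity state transition accumulates it,
#     and output gating contracts the accumulated state with the query token.
state = np.zeros((d_out, d_in))          # SSM hidden state (flattened in practice)
for x_i, y_i in zip(X, Y):
    state = state + np.outer(y_i, x_i)   # gated input, identity state transition
pred_ssm = lr * state @ x_q              # output gating with the query

print(np.allclose(pred_gd, pred_ssm))    # True: both views produce the same output
```

Under these assumptions, the recurrence reproduces the one-step gradient-descent prediction exactly, which mirrors the abstract's claim that input and output gating are the key additions letting a single SSM layer emulate gradient-based in-context learning.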

@article{sushma2025_2410.11687,
  title={State-space models can learn in-context by gradient descent},
  author={Neeraj Mohan Sushma and Yudou Tian and Harshvardhan Mestha and Nicolo Colombo and David Kappel and Anand Subramoney},
  journal={arXiv preprint arXiv:2410.11687},
  year={2025}
}