The goal of this paper is to investigate the complexity of gradient algorithms when learning sparse functions (juntas). We introduce a type of Statistical Queries (SQ), which we call Differentiable Learning Queries (DLQ), to model gradient queries on a specified loss with respect to an arbitrary model. We provide a tight characterization of the query complexity of DLQ for learning the support of a sparse function over generic product distributions. This complexity crucially depends on the loss function. For the squared loss, DLQ matches the complexity of Correlation Statistical Queries (CSQ), which can be much worse than SQ. But for other simple loss functions, including the $\ell_1$ loss, DLQ always achieves the same complexity as SQ. We also provide evidence that DLQ can indeed capture learning with (stochastic) gradient descent by showing that it correctly describes the complexity of learning with a two-layer neural network in the mean field regime and linear scaling.
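For concreteness, one way to read "gradient queries on a specified loss with respect to an arbitrary model" is sketched below; this is a hedged paraphrase, and the oracle form, the tolerance parameter $\tau$, and the choice of norm are assumptions rather than quotes from the abstract. The learner submits a differentiable model $f(\cdot;\theta)$ together with a loss $\ell$, and the oracle returns the population gradient up to an adversarial perturbation of size $\tau$:
\[
  \mathbb{E}_{(x,y)\sim\mathcal{D}}\bigl[\nabla_{\theta}\,\ell\bigl(f(x;\theta),\,y\bigr)\bigr] + \xi,
  \qquad \|\xi\|_{\infty} \le \tau .
\]
Under this reading, a CSQ oracle corresponds to the special case obtained from the squared loss, which is consistent with the dependence on the loss function described above.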