Pruning is Optimal for Learning Sparse Features in High-Dimensions

Main: 10 pages
Bibliography: 4 pages
Appendix: 46 pages
Abstract

While it is commonly observed in practice that pruning networks to a certain level of sparsity can improve the quality of the features, a theoretical explanation of this phenomenon remains elusive. In this work, we investigate this by demonstrating that a broad class of statistical models can be optimally learned using pruned neural networks trained with gradient descent, in high dimensions. We consider learning both single-index and multi-index models of the form $y = \sigma^*(\boldsymbol{V}^{\top} \boldsymbol{x}) + \epsilon$, where $\sigma^*$ is a degree-$p$ polynomial, and $\boldsymbol{V} \in \mathbb{R}^{d \times r}$ with $r \ll d$ is the matrix containing the relevant model directions. We assume that $\boldsymbol{V}$ satisfies a certain $\ell_q$-sparsity condition for matrices and show that pruning neural networks proportionally to the sparsity level of $\boldsymbol{V}$ improves their sample complexity compared to unpruned networks. Furthermore, we establish Correlational Statistical Query (CSQ) lower bounds in this setting, which take the sparsity level of $\boldsymbol{V}$ into account. We show that if the sparsity level of $\boldsymbol{V}$ exceeds a certain threshold, training pruned networks with a gradient descent algorithm achieves the sample complexity suggested by the CSQ lower bound. In the same scenario, however, our results imply that basis-independent methods, such as models trained via standard gradient descent initialized with rotationally invariant random weights, can provably achieve only suboptimal sample complexity.
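To make the data-generating model concrete, the following is a minimal NumPy sketch of sampling from a sparse multi-index model of the form described above. All specific choices here are illustrative assumptions, not from the paper: the dimensions `d`, `r`, the row-sparsity level `s` (a simple hard-sparsity stand-in for the paper's $\ell_q$-condition), the particular degree-3 link polynomial, and the noise scale.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 1000, 2   # ambient dimension and number of relevant directions, r << d
s = 10           # number of nonzero rows of V (illustrative hard-sparsity proxy)
n = 500          # number of samples

# Sparse direction matrix V in R^{d x r}: only s rows are nonzero
V = np.zeros((d, r))
support = rng.choice(d, size=s, replace=False)
V[support] = rng.standard_normal((s, r))
V /= np.linalg.norm(V, axis=0)  # unit-norm columns

# Inputs x ~ N(0, I_d); projections z = V^T x
X = rng.standard_normal((n, d))
Z = X @ V  # shape (n, r)

# A degree-3 polynomial link (hypothetical choice of sigma^*),
# written in Hermite-type terms z^3 - 3z and z^2 - 1
def sigma_star(z):
    return z[:, 0] ** 3 - 3.0 * z[:, 0] + z[:, 1] ** 2 - 1.0

# Labels y = sigma^*(V^T x) + noise
y = sigma_star(Z) + 0.1 * rng.standard_normal(n)
```

A pruning-aware learner in this setting would exploit that only the `s` coordinates in `support` carry signal, which is what drives the improved sample complexity discussed in the abstract.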
