
Learning quadratic neural networks in high dimensions: SGD dynamics and scaling laws

Main: 12 Pages
6 Figures
Bibliography: 5 Pages
1 Table
Appendix: 71 Pages
Abstract

We study the optimization and sample complexity of gradient-based training of a two-layer neural network with quadratic activation function in the high-dimensional regime, where the data is generated as $y \propto \sum_{j=1}^{r}\lambda_j \sigma\left(\langle \boldsymbol{\theta}_j, \boldsymbol{x}\rangle\right)$, $\boldsymbol{x} \sim N(0,\boldsymbol{I}_d)$, where $\sigma$ is the 2nd Hermite polynomial, and $\lbrace\boldsymbol{\theta}_j \rbrace_{j=1}^{r} \subset \mathbb{R}^d$ are orthonormal signal directions. We consider the extensive-width regime $r \asymp d^\beta$ for $\beta \in [0, 1)$, and assume a power-law decay on the (non-negative) second-layer coefficients, $\lambda_j \asymp j^{-\alpha}$ for $\alpha \geq 0$. We present a sharp analysis of the SGD dynamics in the feature learning regime, for both the population limit and the finite-sample (online) discretization, and derive scaling laws for the prediction risk that highlight the power-law dependencies on the optimization time, sample size, and model width. Our analysis combines a precise characterization of the associated matrix Riccati differential equation with novel matrix monotonicity arguments to establish convergence guarantees for the infinite-dimensional effective dynamics.
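A minimal illustrative sketch (not the authors' code) of the data model and the online SGD setup described in the abstract. It assumes the probabilists' normalization $\sigma(z) = z^2 - 1$ for the 2nd Hermite polynomial, takes the first $r$ standard basis vectors as the orthonormal teacher directions $\boldsymbol{\theta}_j$, and uses an arbitrary choice of student width, step size, and coefficient normalization.

# Hypothetical sketch: data from the quadratic multi-index model and
# online SGD on a width-m student with the same quadratic activation.
import numpy as np

rng = np.random.default_rng(0)
d, beta, alpha = 64, 0.5, 1.0
r = max(1, int(round(d ** beta)))       # extensive width: r ~ d^beta
m = 2 * r                               # student width (assumption)
lr, n_steps = 1e-2, 50_000              # step size and number of online steps (assumption)

lam = np.arange(1, r + 1, dtype=float) ** (-alpha)   # power-law coefficients lambda_j ~ j^{-alpha}
lam /= np.linalg.norm(lam)                           # normalization is an assumption ("y proportional to")
Theta = np.eye(d)[:r]                                # orthonormal teacher directions, shape (r, d)

def hermite2(z):
    return z ** 2 - 1.0                 # probabilists' 2nd Hermite polynomial He_2 (assumed convention)

def target(x):
    return lam @ hermite2(Theta @ x)    # y = sum_j lambda_j * He_2(<theta_j, x>)

W = rng.normal(scale=1.0 / np.sqrt(d), size=(m, d))  # student first-layer weights
a = np.full(m, 1.0 / m)                              # fixed second layer (assumption)

for t in range(n_steps):
    x = rng.normal(size=d)                           # fresh Gaussian sample each step (online SGD)
    pre = W @ x
    y_hat = a @ hermite2(pre)
    err = y_hat - target(x)
    # gradient of 0.5 * err^2 w.r.t. row w_i is err * a_i * He_2'(<w_i, x>) * x, with He_2'(z) = 2z
    W -= lr * err * (a * 2.0 * pre)[:, None] * x[None, :]

# estimate the prediction risk on a held-out batch
X_test = rng.normal(size=(2000, d))
risk = np.mean((hermite2(X_test @ W.T) @ a - hermite2(X_test @ Theta.T) @ lam) ** 2)
print(f"test risk after {n_steps} online steps: {risk:.4f}")

Sweeping d, the step count, and the student width in such a sketch is one way to probe the power-law dependencies of the risk on time, sample size, and width that the paper derives analytically.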
