Learning Hierarchical Polynomials with Three-Layer Neural Networks

Abstract

We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. We specifically consider target functions of the form $h = g \circ p$ where $p : \mathbb{R}^d \rightarrow \mathbb{R}$ is a degree $k$ polynomial and $g : \mathbb{R} \rightarrow \mathbb{R}$ is a degree $q$ polynomial. This function class generalizes the single-index model, which corresponds to $k = 1$, and is a natural class of functions possessing an underlying hierarchical structure. Our main result shows that for a large subclass of degree $k$ polynomials $p$, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target $h$ up to vanishing test error in $\widetilde{\mathcal{O}}(d^k)$ samples and polynomial time. This is a strict improvement over kernel methods, which require $\widetilde{\Theta}(d^{kq})$ samples, as well as over existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior works on three-layer neural networks, which were restricted to the case of $p$ being a quadratic. When $p$ is indeed a quadratic, we achieve the information-theoretically optimal sample complexity $\widetilde{\mathcal{O}}(d^2)$, an improvement over prior work~\citep{nichani2023provable} requiring a sample size of $\widetilde{\Theta}(d^4)$. Our proof proceeds by showing that during the initial stage of training, the network performs feature learning to recover the feature $p$ with $\widetilde{\mathcal{O}}(d^k)$ samples. This work demonstrates the ability of three-layer neural networks to learn complex features and, as a result, learn a broad class of hierarchical functions.
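To make the target class concrete, the following is a minimal sketch of a hierarchical polynomial $h = g \circ p$ in the paper's setting, with inputs drawn from the standard Gaussian. The specific choices here ($d = 5$, a quadratic form for $p$, i.e. $k = 2$, and a cubic link $g$, i.e. $q = 3$) are illustrative assumptions, not the paper's construction.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5

# p : R^d -> R, a degree-k polynomial with k = 2 here,
# chosen as a quadratic form x^T A x (an illustrative assumption)
A = rng.standard_normal((d, d))
A = (A + A.T) / 2  # symmetrize so p is a well-defined quadratic form

def p(x):
    return x @ A @ x

# g : R -> R, a degree-q polynomial with q = 3 here,
# so the composition h = g(p(x)) has total degree kq = 6
def g(z):
    return z**3 - 2.0 * z + 1.0

def h(x):
    return g(p(x))

# Inputs sampled from the standard Gaussian distribution, as in the paper
X = rng.standard_normal((4, d))
y = np.array([h(x) for x in X])
print(y.shape)  # (4,)
```

Under this setup, a kernel method would need on the order of $d^{kq} = d^6$ samples to fit $h$, whereas the paper's layerwise analysis needs only $\widetilde{\mathcal{O}}(d^k) = \widetilde{\mathcal{O}}(d^2)$ by first recovering the inner feature $p$.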
