
Scaling description of generalization with number of parameters in deep learning

Abstract

We provide a description for the evolution of the generalization performance of fixed-depth fully-connected deep neural networks, as a function of their number of parameters $N$. As $N$ gets large, we observe that increasing $N$ at fixed depth reduces the fluctuations of the output function $f_N$ induced by initial conditions, with $\|f_N - \bar{f}_N\| \sim N^{-1/4}$, where $\bar{f}_N$ denotes an average over initial conditions. We explain this asymptotic behavior in terms of the fluctuations of the so-called Neural Tangent Kernel that controls the dynamics of the output function. For the task of classification, we predict these fluctuations to increase the true test error $\epsilon$ as $\epsilon_N - \epsilon_\infty \sim N^{-1/2} + \mathcal{O}(N^{-3/4})$. This prediction is consistent with our empirical results on the MNIST dataset, and it explains in a concrete case the puzzling observation that the predictive power of deep networks improves as the number of fitting parameters grows. For smaller $N$, this asymptotic description breaks down at a so-called jamming transition, which takes place at a critical $N = N^*$, below which the training error is non-zero. In the absence of regularization, we observe an apparent divergence $\|f_N\| \sim (N - N^*)^{-\alpha}$ and provide a simple argument suggesting $\alpha = 1$, consistent with empirical observations. This result leads to a plausible explanation for the cusp in test error known to occur at $N^*$. Overall, our analysis suggests that once models are averaged, the optimal model complexity is reached just beyond the point where the data can be perfectly fitted, a result of practical importance that needs to be tested in a wide range of architectures and datasets.
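The scaling $\|f_N - \bar{f}_N\| \sim N^{-1/4}$ quantifies how much trained networks of a given size scatter around their ensemble average over random initializations. Below is a minimal sketch of how such a fluctuation could be measured, assuming a small fully-connected PyTorch network trained on synthetic data rather than the paper's MNIST setup; the network, training loop, and hyperparameters are illustrative assumptions, not the authors' code.

```python
# Illustrative sketch (not the authors' implementation): estimate the
# spread of the trained output function f_N around its mean over random
# initializations, for a fixed-depth fully-connected network.
import torch
import torch.nn as nn

# Synthetic binary-classification data (stand-in for MNIST).
torch.manual_seed(0)
X_train = torch.randn(256, 2)
y_train = (X_train[:, 0] * X_train[:, 1] > 0).float().unsqueeze(1)
X_test = torch.randn(512, 2)

def make_net(width, depth=3):
    """Fixed-depth fully-connected network; N grows with the width."""
    layers, d_in = [], 2
    for _ in range(depth):
        layers += [nn.Linear(d_in, width), nn.ReLU()]
        d_in = width
    layers.append(nn.Linear(d_in, 1))
    return nn.Sequential(*layers)

def train(net, epochs=500, lr=1e-2):
    opt = torch.optim.SGD(net.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        opt.zero_grad()
        loss_fn(net(X_train), y_train).backward()
        opt.step()
    return net

def fluctuation(width, n_seeds=8):
    """RMS deviation of f_N from its average over n_seeds initializations."""
    outputs = []
    for seed in range(n_seeds):
        torch.manual_seed(seed)
        net = train(make_net(width))
        with torch.no_grad():
            outputs.append(net(X_test))
    f = torch.stack(outputs)                 # (n_seeds, n_test, 1)
    f_bar = f.mean(dim=0, keepdim=True)      # ensemble average over seeds
    return (f - f_bar).pow(2).mean().sqrt().item()

# Per the abstract, this quantity should shrink roughly like N^{-1/4}
# as the number of parameters N (here, the width) increases.
for w in (16, 64, 256):
    print(w, fluctuation(w))
```

In the same spirit, averaging the outputs of the independently initialized networks (the ensemble $\bar{f}_N$) is what the abstract refers to when stating that, once models are averaged, the optimal complexity sits just beyond the jamming point $N^*$.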
