Neural Stein critics with staged L^2-regularization

Matthew Repasky
Xiuyuan Cheng
Yao Xie
Abstract

Learning to differentiate model distributions from observed data is a fundamental problem in statistics and machine learning, and high-dimensional data remains a challenging setting for such problems. Metrics that quantify the disparity between probability distributions, such as the Stein discrepancy, play an important role in high-dimensional statistical testing. In this paper, we investigate the role of L^2 regularization in training a neural network Stein critic so as to distinguish between data sampled from an unknown probability distribution and a nominal model distribution. Making a connection to the Neural Tangent Kernel (NTK) theory, we develop a novel staging procedure for the weight of regularization over training time, which leverages the advantages of highly regularized training at early times. Theoretically, we prove the approximation of the training dynamic by the kernel optimization, namely the "lazy training", when the L^2 regularization weight is large, and that training on n samples converges at a rate of O(n^{-1/2}) up to a log factor. The result guarantees learning the optimal critic assuming sufficient alignment with the leading eigen-modes of the zero-time NTK. The benefit of the staged L^2 regularization is demonstrated on simulated high-dimensional data and an application to evaluating generative models of image data.
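The staging idea described above can be illustrated with a minimal sketch. This is not the paper's implementation: the schedule shape, the weight values, and the empirical objective below are all illustrative assumptions. The schedule keeps the L^2 weight large early in training (where the dynamics stay close to the "lazy" NTK regime) and then drops it, and the loss maximizes an empirical Stein-operator term minus an L^2 penalty on the critic values.

```python
import numpy as np

def staged_l2_weight(step, total_steps, lam_early=10.0, lam_late=0.1, switch_frac=0.3):
    """Staged L2-regularization weight (illustrative values): large for the
    first fraction of training steps, then small for the remainder."""
    return lam_early if step < switch_frac * total_steps else lam_late

def regularized_stein_loss(stein_values, critic_values, lam):
    """Empirical regularized objective (negated, so gradient descent applies):
    - stein_values:  Stein-operator outputs T_q f(x_i) on the n samples
    - critic_values: critic outputs f(x_i), penalized in squared L2 norm
    """
    return -np.mean(stein_values) + lam * np.mean(critic_values ** 2)

# Toy demonstration on synthetic numbers (not real Stein-operator outputs).
rng = np.random.default_rng(0)
stein_vals = rng.normal(size=100)
critic_vals = rng.normal(size=100)

loss_early = regularized_stein_loss(stein_vals, critic_vals,
                                    staged_l2_weight(step=0, total_steps=1000))
loss_late = regularized_stein_loss(stein_vals, critic_vals,
                                   staged_l2_weight(step=900, total_steps=1000))
```

With the same samples, the early-stage loss carries a much heavier penalty than the late-stage loss, reflecting the transition from heavily regularized (near-kernel) training to a weakly regularized fine-tuning stage.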
