Fast Differentiable Clipping-Aware Normalization and Rescaling

Rescaling a vector $\delta \in \mathbb{R}^n$ to a desired length is a common operation in many areas such as data science and machine learning. When the rescaled perturbation $\eta\delta$ is added to a starting point $x \in D$ (where $D$ is the data domain, e.g. $D = [0, 1]^n$), the resulting vector $v = x + \eta\delta$ will in general not be in $D$. To enforce that the perturbed vector $v$ is in $D$, the values of $v$ can be clipped to $D$. This subsequent element-wise clipping to the data domain does, however, reduce the effective perturbation size and thus interferes with the rescaling of $\delta$. The optimal rescaling $\eta$ to obtain a perturbation with the desired norm after the clipping can be iteratively approximated using a binary search. However, such an iterative approach is slow and non-differentiable. Here we show that the optimal rescaling can be found analytically using a fast and differentiable algorithm. Our algorithm works for any $p$-norm and can be used to train neural networks on inputs with normalized perturbations. We provide native implementations for PyTorch, TensorFlow, JAX, and NumPy based on EagerPy.
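The core idea can be illustrated for the Euclidean norm ($p = 2$). The sketch below is a reconstruction from the problem statement above, not the authors' reference code: the function name, the restriction to $D = [0, 1]^n$, and the single-vector (non-batched) interface are our assumptions. It exploits that $\lVert \mathrm{clip}(x + \eta\delta) - x \rVert_2^2 = \sum_i \min(\eta\,|\delta_i|,\, b_i)^2$, where $b_i$ is the per-element room before clipping, is a piecewise quadratic, non-decreasing function of $\eta$, so the crossing with the target $\epsilon^2$ can be located exactly after sorting the breakpoints at which elements saturate.

```python
import numpy as np

def l2_clipping_aware_rescaling(x, delta, eps):
    """Find eta >= 0 such that ||clip(x + eta * delta, 0, 1) - x||_2 == eps.

    Assumes x lies in [0, 1]^n and delta has at least one nonzero entry.
    Raises if eps is larger than the norm of the fully clipped perturbation.
    """
    x = np.asarray(x, dtype=float)
    d = np.asarray(delta, dtype=float)
    a = np.abs(d)
    b = np.where(d >= 0, 1.0 - x, x)        # per-element room before clipping
    mask = a > 0                            # zero components never contribute
    a, b = a[mask], b[mask]
    t = b / a                               # eta at which element i saturates
    order = np.argsort(t)
    t, a2, b2 = t[order], a[order] ** 2, b[order] ** 2
    # Past the k-th sorted breakpoint, the first k elements are clipped, so
    # f(eta) = ||clip(x + eta * delta) - x||_2^2 = S[k] * eta^2 + C[k].
    S = a2.sum() - np.concatenate(([0.0], np.cumsum(a2)))  # unclipped a_i^2
    C = np.concatenate(([0.0], np.cumsum(b2)))             # clipped b_i^2
    f_at_t = S[:-1] * t ** 2 + C[:-1]       # f evaluated at each breakpoint
    k = np.searchsorted(f_at_t, eps ** 2)   # segment where f crosses eps^2
    if k == len(t):
        raise ValueError("eps exceeds the largest reachable perturbation size")
    return np.sqrt((eps ** 2 - C[k]) / S[k])

# Quick check: the rescaled-and-clipped perturbation has the requested norm.
rng = np.random.default_rng(0)
x = rng.random(100)
delta = rng.standard_normal(100)
eta = l2_clipping_aware_rescaling(x, delta, eps=1.0)
v = np.clip(x + eta * delta, 0.0, 1.0)
assert np.isclose(np.linalg.norm(v - x), 1.0)
```

The paper's algorithm goes further than this sketch: it covers arbitrary $p$-norms and is differentiable, so it can be used inside training loops; for practical use, the authors' EagerPy-based implementations should be preferred.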
View on arXiv