
Trainable Weight Averaging: Accelerating Training and Improving Generalization

Abstract

Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). Existing approaches such as stochastic weight averaging (SWA) rely on pre-set weighting schemes, which can be suboptimal when handling diverse weights. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns optimal weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression of the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility; (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance; and (iii) during fine-tuning, it significantly enhances generalization through weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at this https URL.

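To make the idea concrete, here is a minimal PyTorch sketch of a TWA-v-style procedure as the abstract describes it: a handful of candidate checkpoints span a small subspace, and the coefficients that combine them are learned by minimizing a loss on held-out data. The function name, the softmax parameterization of the coefficients, the optimizer, and all hyperparameters are illustrative assumptions, not the paper's actual implementation.

# A toy sketch of trainable weight averaging, not the authors' implementation:
# candidate checkpoints form a basis for a small subspace, and the averaging
# coefficients are learned on held-out data (TWA-v style). All names and
# hyperparameters below are illustrative assumptions.
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0


def trainable_weight_average(candidate_states, model, val_loader, steps=100, lr=0.1):
    """Learn coefficients that combine `candidate_states` (a list of state_dicts)."""
    device = next(model.parameters()).device
    keys = [k for k, _ in model.named_parameters()]
    shapes = [candidate_states[0][k].shape for k in keys]
    numels = [s.numel() for s in shapes]

    # Flatten each candidate checkpoint into one vector; stacked, they span the subspace.
    basis = torch.stack([
        torch.cat([sd[k].detach().reshape(-1).to(device) for k in keys])
        for sd in candidate_states
    ])  # shape: (num_candidates, num_params)

    def unflatten(vec):
        # Rebuild a {name: tensor} dict from the flat averaged vector.
        return {k: c.view(s) for k, c, s in zip(keys, torch.split(vec, numels), shapes)}

    # One trainable logit per candidate; softmax keeps the combination convex.
    logits = torch.zeros(len(candidate_states), device=device, requires_grad=True)
    opt = torch.optim.Adam([logits], lr=lr)
    criterion = nn.CrossEntropyLoss()  # assumes a classification model

    for _, (x, y) in zip(range(steps), val_loader):
        coeffs = torch.softmax(logits, dim=0)
        params = unflatten(coeffs @ basis)  # weighted average of the candidates
        loss = criterion(functional_call(model, params, (x.to(device),)), y.to(device))
        opt.zero_grad()
        loss.backward()
        opt.step()

    with torch.no_grad():
        coeffs = torch.softmax(logits, dim=0)
        return unflatten(coeffs @ basis), coeffs

The returned dict contains only parameters (no buffers), so it can be loaded back with model.load_state_dict(averaged, strict=False); in practice one would also recompute batch-normalization statistics afterwards, as is standard for SWA.
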
@article{li2025_2205.13104,
  title={Trainable Weight Averaging: Accelerating Training and Improving Generalization},
  author={Tao Li and Zhehao Huang and Yingwen Wu and Zhengbao He and Qinghua Tao and Xiaolin Huang and Chih-Jen Lin},
  journal={arXiv preprint arXiv:2205.13104},
  year={2025}
}