
Distributed Sign Momentum with Local Steps for Training Transformers

Abstract

Pre-training Transformer models is resource-intensive, and recent studies have shown that sign momentum is an efficient technique for training large-scale deep learning models, particularly Transformers. However, its application in distributed training remains underexplored. This paper investigates a novel communication-efficient distributed sign momentum method with multiple local steps, to cope with scenarios where communicating at every step is prohibitive. Our proposed method allows for a broad class of base optimizers for local steps and uses sign momentum in the global step, where the momentum is generated from the differences accumulated during local steps. For generic base optimizers, by approximating the sign operator with a randomized version that acts as a continuous analog in expectation, we present a general convergence analysis, which specializes to an $O(1/\sqrt{T})$ rate for a particular instance. When the local step is stochastic gradient descent, we show an optimal $O(1/T^{1/4})$ rate in terms of the $\ell_1$ gradient norm for nonconvex smooth cost functions. We extensively evaluate our method on the pre-training of GPT-2 models of various sizes from scratch, and the empirical results show significant improvement compared to other distributed methods with multiple local steps.
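For intuition, the round structure described in the abstract (local steps with a base optimizer, then a global sign-momentum step built from the accumulated differences) can be sketched as below. This is a minimal illustration under our own assumptions, not the authors' reference implementation; the names `H`, `beta`, `local_lr`, `global_lr`, and `grads_fn` are hypothetical, and the base optimizer is taken to be plain SGD for simplicity.

```python
# Minimal sketch of one communication round of a distributed
# sign-momentum scheme with local steps (illustrative, not the paper's code).
import torch

def local_round(global_params, grads_fn, local_lr=1e-3, H=4):
    """Run H local base-optimizer (SGD) steps from a copy of the global
    parameters; return the accumulated difference (local minus global)."""
    start = [p.detach().clone() for p in global_params]
    params = [p.detach().clone() for p in global_params]
    for _ in range(H):
        grads = grads_fn(params)          # stochastic gradients at current local params
        for p, g in zip(params, grads):
            p -= local_lr * g             # base optimizer step: plain SGD here
    return [p - s for p, s in zip(params, start)]

def global_step(global_params, worker_diffs, momentum, beta=0.9, global_lr=1e-2):
    """Average the workers' accumulated differences, fold them into a
    momentum buffer, and apply the sign of that momentum globally.
    global_params are assumed to be plain tensors (no autograd leaves)."""
    for i, p in enumerate(global_params):
        avg_diff = torch.stack([d[i] for d in worker_diffs]).mean(dim=0)
        momentum[i] = beta * momentum[i] + (1 - beta) * avg_diff
        p += global_lr * torch.sign(momentum[i])
    return momentum
```

In a usage loop one would initialize `momentum = [torch.zeros_like(p) for p in global_params]`, have each worker call `local_round` on its own data, and then call `global_step` once per communication round, so that full parameter synchronization happens only every `H` local steps.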

@article{yu2025_2411.17866,
  title={Distributed Sign Momentum with Local Steps for Training Transformers},
  author={Shuhua Yu and Ding Zhou and Cong Xie and An Xu and Zhi Zhang and Xin Liu and Soummya Kar},
  journal={arXiv preprint arXiv:2411.17866},
  year={2025}
}