JaxPruner: A concise library for sparsity research
Jooyoung Lee
Wonpyo Park
Nicole Mitchell
Jonathan Pilault
J. Obando-Ceron
Han-Byul Kim
Namhoon Lee
Elias Frantar
Yun Long
Amir Yazdanbakhsh
Shivani Agrawal
Suvinay Subramanian
Xin Wang
Sheng-Chun Kao
Xingyao Zhang
Trevor Gale
Aart J. C. Bik
Woohyun Han
Milen Ferev
Zhonglin Han
Hong-Seok Kim
Yann N. Dauphin
Karolina Dziugaite
P. S. Castro
Utku Evci

Abstract
This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which in turn enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration by providing examples in four different codebases: Scenic, t5x, Dopamine, and FedJAX, and we report baseline experiments on popular benchmarks.
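As a rough illustration of the Optax integration described in the abstract, the sketch below wraps a standard Optax optimizer with a pruning updater and runs one training step. The function and configuration names used here (create_updater_from_config, wrap_optax, and the config fields) are assumptions based on that description, not a verified API reference.

import jax
import jax.numpy as jnp
import optax
import jaxpruner
import ml_collections

# Configure a pruning algorithm (field names are illustrative assumptions).
config = ml_collections.ConfigDict()
config.algorithm = 'magnitude'     # magnitude-based pruning
config.sparsity = 0.8              # target 80% sparsity
config.update_freq = 10            # apply sparsity updates every 10 steps
config.update_start_step = 200
config.update_end_step = 1000

pruner = jaxpruner.create_updater_from_config(config)

# Wrap a standard Optax optimizer; the rest of the training loop is unchanged.
tx = optax.adam(learning_rate=1e-3)
tx = pruner.wrap_optax(tx)

params = {'w': jnp.ones((16, 16))}
opt_state = tx.init(params)

def loss_fn(p, x):
    return jnp.sum(p['w'] @ x)

grads = jax.grad(loss_fn)(params, jnp.ones((16,)))
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Because the pruner is expressed as a wrapper around an Optax gradient transformation, the same training loop can be reused across codebases such as Scenic, t5x, Dopamine, and FedJAX, which is the integration path the abstract emphasizes.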