Sampling-based Distributed Training with Message Passing Neural Network

In this study, we introduce a domain-decomposition-based distributed training and inference approach for message-passing neural networks (MPNNs). Our objective is to address the challenge of scaling edge-based graph neural networks as the number of nodes increases. By combining this distributed training approach with Nyström-approximation sampling, we present a scalable graph neural network, DS-MPNN (D and S stand for distributed and sampled, respectively), capable of scaling up to nodes. We validate our sampling and distributed training approach on two cases: (a) a Darcy flow dataset and (b) steady RANS simulations of 2-D airfoils, comparing against both a single-GPU implementation and node-based graph convolutional networks (GCNs). DS-MPNN achieves accuracy comparable to the single-GPU implementation, accommodates a significantly larger number of nodes than the single-GPU variant (S-MPNN), and significantly outperforms the node-based GCN.
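To make the two ingredients concrete, the sketch below combines a domain decomposition of the node set with random subsampling of nodes, followed by one message-passing step. It is a minimal illustration under stated assumptions, not the authors' DS-MPNN implementation. The following are all hypothetical choices: the slab decomposition by x-coordinate (`decompose`), uniform sampling of at most `m` nodes per subdomain as a stand-in for Nyström-style sampling (`sample`), radius-based neighbourhoods with mean aggregation (`message_pass`), the halo of off-subdomain neighbours (`distributed_step`), and the fixed `kernel` function that stands in for a learned edge network. Each subdomain is processed independently, the way one GPU would process its partition.

```python
import math
import random


def decompose(points, n_sub):
    """Hypothetical slab decomposition: assign each node to one of n_sub strips by x."""
    xs = [p[0] for p in points]
    lo, hi = min(xs), max(xs)
    width = (hi - lo) / n_sub or 1.0  # guard against all nodes sharing one x value
    parts = [[] for _ in range(n_sub)]
    for i, (x, _) in enumerate(points):
        parts[min(int((x - lo) / width), n_sub - 1)].append(i)
    return parts


def sample(indices, m, rng):
    """Uniformly keep at most m nodes (assumed stand-in for Nystrom-style sampling)."""
    if len(indices) <= m:
        return sorted(indices)
    return sorted(rng.sample(indices, m))


def message_pass(points, feats, targets, sources, radius, kernel):
    """Mean aggregation: v'(x) = mean over sources y with |x - y| <= radius of kernel(x, y) * v(y)."""
    out = {}
    for i in targets:
        acc, count = 0.0, 0
        for j in sorted(sources):  # fixed order so results are reproducible
            if math.dist(points[i], points[j]) <= radius:
                acc += kernel(points[i], points[j]) * feats[j]
                count += 1
        out[i] = acc / count if count else 0.0
    return out


def distributed_step(points, feats, n_sub, m, radius, kernel, seed=0):
    """Each subdomain updates only the sampled nodes it owns."""
    rng = random.Random(seed)
    owned = [sample(part, m, rng) for part in decompose(points, n_sub)]
    all_sampled = sorted(i for part in owned for i in part)
    result = {}
    for part in owned:
        mine = set(part)
        # Halo: sampled nodes owned by other subdomains that lie within `radius`
        # of this subdomain's nodes; a multi-GPU run would have to communicate these.
        halo = [j for j in all_sampled if j not in mine
                and any(math.dist(points[i], points[j]) <= radius for i in part)]
        result.update(message_pass(points, feats, part, part + halo, radius, kernel))
    return result, all_sampled


def gaussian_kernel(p, q):
    """Fixed edge weight used here in place of a learned edge network."""
    return math.exp(-math.dist(p, q) ** 2)


if __name__ == "__main__":
    rng = random.Random(42)
    pts = [(rng.random(), rng.random()) for _ in range(200)]
    f = [x + y for x, y in pts]
    dist_out, sampled = distributed_step(pts, f, n_sub=4, m=30, radius=0.2,
                                         kernel=gaussian_kernel)
    ref = message_pass(pts, f, sampled, sampled, 0.2, gaussian_kernel)
    print(max(abs(dist_out[i] - ref[i]) for i in sampled))
```

The halo is what makes the partitioned update agree with a single-graph update over the same sampled nodes. The printed maximum difference against the undecomposed reference is therefore zero. Without a halo, nodes near subdomain boundaries would lose neighbours, and the distributed result would drift from the single-GPU one.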
@article{kakka2025_2402.15106,
  title={Sampling-based Distributed Training with Message Passing Neural Network},
  author={Priyesh Kakka and Sheel Nidhan and Rishikesh Ranade and Jay Pathak and Jonathan F. MacArt},
  journal={arXiv preprint arXiv:2402.15106},
  year={2025}
}