Sampling-based Distributed Training with Message Passing Neural Network

Abstract

In this study, we introduce a domain-decomposition-based distributed training and inference approach for message-passing neural networks (MPNNs). Our objective is to address the challenge of scaling edge-based graph neural networks as the number of nodes increases. Through our distributed training approach, coupled with Nyström-approximation sampling techniques, we present a scalable graph neural network, referred to as DS-MPNN (D and S standing for distributed and sampled, respectively), capable of scaling up to $O(10^5)$ nodes. We validate our sampling and distributed training approach on two cases: (a) a Darcy flow dataset and (b) steady RANS simulations of 2-D airfoils, providing comparisons with both a single-GPU implementation and node-based graph convolution networks (GCNs). The DS-MPNN model demonstrates accuracy comparable to the single-GPU implementation, can accommodate a significantly larger number of nodes than the single-GPU variant (S-MPNN), and significantly outperforms the node-based GCN.

@article{kakka2025_2402.15106,
  title={Sampling-based Distributed Training with Message Passing Neural Network},
  author={Priyesh Kakka and Sheel Nidhan and Rishikesh Ranade and Jay Pathak and Jonathan F. MacArt},
  journal={arXiv preprint arXiv:2402.15106},
  year={2025}
}