
Scalable Multi-Objective and Meta Reinforcement Learning via Gradient Estimation

Main: 7 pages · Appendix: 8 pages · Bibliography: 2 pages · 4 figures · 4 tables
Abstract

We study the problem of efficiently estimating policies that simultaneously optimize multiple objectives in reinforcement learning (RL). Given n objectives (or tasks), we seek the optimal partition of these objectives into k ≪ n groups, where each group comprises related objectives that can be trained together. This problem arises in applications such as robotics, control, and preference optimization in language models, where learning a single policy for all n objectives becomes suboptimal as n grows. We introduce a two-stage procedure, meta-training followed by fine-tuning, to address this problem. We first learn a meta-policy for all objectives using multitask learning. Then, we adapt the meta-policy to multiple randomly sampled subsets of objectives. The adaptation step leverages a first-order approximation property of well-trained policy networks, which is empirically verified to be accurate to within a 2% error margin across various RL environments. The resulting algorithm, PolicyGradEx, efficiently estimates an aggregate task-affinity score matrix given a policy evaluation algorithm. Based on the estimated affinity score matrix, we cluster the n objectives into k groups by maximizing the intra-cluster affinity scores. Experiments on three robotic control benchmarks and the Meta-World benchmark demonstrate that our approach outperforms state-of-the-art baselines by 16% on average, while delivering up to a 26× speedup relative to performing full training to obtain the clusters. Ablation studies validate each component of our approach. For instance, compared with random grouping and gradient-similarity-based grouping, our loss-based clustering yields an improvement of 19%. Finally, we analyze the generalization error of policy networks by measuring the Hessian trace of the loss surface, which gives non-vacuous measures relative to the observed generalization errors.
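The two-stage idea in the abstract can be illustrated with a small, self-contained sketch: use a first-order Taylor expansion around the shared meta-policy parameters to predict how fine-tuning on a randomly sampled subset of objectives would change each objective's loss, aggregate those predictions into a task-affinity matrix, and then group the n objectives into k clusters by maximizing intra-cluster affinity. Everything below (the toy gradients, the names first_order_loss_delta and greedy_cluster, and the greedy assignment) is an illustrative assumption, not the paper's PolicyGradEx implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tasks, dim, lr, n_subsets = 8, 64, 0.1, 200

# Stand-ins for per-task gradients of the meta-policy's loss; in practice these
# would come from policy-gradient estimates at the shared meta-parameters.
grads = rng.normal(size=(n_tasks, dim))

def first_order_loss_delta(task_i, subset):
    """Predicted change in task i's loss after one fine-tuning step on `subset`:
    L_i(theta - lr * sum_j g_j) - L_i(theta) ~= -lr * <g_i, sum_j g_j>."""
    update = grads[list(subset)].sum(axis=0)
    return -lr * float(grads[task_i] @ update)

# Aggregate an affinity matrix: task j gets credit toward task i whenever a
# sampled subset containing j is predicted to reduce task i's loss.
affinity = np.zeros((n_tasks, n_tasks))
counts = np.zeros((n_tasks, n_tasks))
for _ in range(n_subsets):
    subset = rng.choice(n_tasks, size=3, replace=False)
    for i in range(n_tasks):
        delta = first_order_loss_delta(i, subset)
        for j in subset:
            affinity[i, j] += -delta  # larger score = larger predicted loss drop
            counts[i, j] += 1
affinity /= np.maximum(counts, 1)

def greedy_cluster(scores, k):
    """Assign each task to the group with the highest average pairwise affinity
    to its current members (a simple stand-in for the clustering step)."""
    groups = [[] for _ in range(k)]
    for t in range(scores.shape[0]):
        gains = [
            0.0 if not g else np.mean([scores[t, m] + scores[m, t] for m in g])
            for g in groups
        ]
        groups[int(np.argmax(gains))].append(t)
    return groups

print(greedy_cluster(affinity, k=3))
```

The point of the sketch is the cost structure the abstract describes: each affinity estimate only needs gradient and loss evaluations at the shared meta-policy rather than full training on every candidate subset, which is the source of the reported speedup over obtaining clusters via full training.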
