Active Preference Optimization for Sample Efficient RLHF

Main:1 Pages
2 Figures
2 Tables
Appendix:30 Pages
Abstract

Large Language Models (LLMs) aligned using Reinforcement Learning from Human Feedback (RLHF) have shown remarkable generation abilities in numerous tasks. However, collecting high-quality human preferences creates costly bottlenecks in practical deployments, and hence, training data are often budgeted. In these scenarios, it is crucial to collect training data (e.g., contexts, a pair of generations for each context, and a preference indicating which generation is better) carefully, yet most of the existing methods sample contexts uniformly at random from a given collection. Given this, under the Bradley-Terry-Luce preference model and with a small budget of training data, we show that uniform sampling of contexts could lead to a policy (i.e., an aligned model) that suffers a constant sub-optimality gap from the optimal policy. This highlights the need for an adaptive context sampling strategy for effective alignment under a small sample budget. To address this, we reformulate RLHF within the contextual preference bandit framework, treating generations as actions, and give a nearly complete characterization of the sub-optimality gap in terms of both lower and upper bounds. First, when the action set is a $d$-dimensional hypercube and the number of samples is $T$, we show an $\Omega(d/\sqrt{T})$ lower bound. Next, we propose an algorithm, Active Preference Optimization (APO), that iteratively collects preferences for the most uncertain contexts. We show that the sub-optimality gap of the policy learned via APO matches the lower bound up to a log factor and a non-linearity constant. Finally, we perform experiments on practical datasets to validate APO's efficacy over existing methods, establishing it as a sample-efficient and cost-effective solution for LLM alignment.
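
The idea of querying the most uncertain contexts under a Bradley-Terry-Luce model can be illustrated with a minimal sketch. It is not the paper's algorithm: the abstract does not specify the uncertainty measure, so the sketch assumes a linear reward over feature differences and uses the elliptical score z^T V^{-1} z that is common in linear and logistic bandits. The names `btl_prob`, `most_uncertain`, `active_collect`, and the simulated labeler `theta_star` are hypothetical.

```python
import numpy as np

def btl_prob(theta, z):
    """BTL probability that the first generation is preferred.

    Assumes a linear reward; z = phi(x, a) - phi(x, a') is the feature
    difference of the two generations for one context.
    """
    return 1.0 / (1.0 + np.exp(-z @ theta))

def most_uncertain(V, Z):
    """Return the index of the row z of Z maximizing z^T V^{-1} z, plus all scores."""
    V_inv_Zt = np.linalg.solve(V, Z.T)            # shape (d, n)
    scores = np.einsum("ij,ji->i", Z, V_inv_Zt)   # z_i^T V^{-1} z_i for each i
    return int(np.argmax(scores)), scores

def active_collect(Z, theta_star, budget, lam=1.0, seed=0):
    """Greedily query `budget` preferences, one context per round.

    Z: (n, d) array, one feature difference per candidate context.
    theta_star: (d,) hypothetical ground-truth parameter used only to
    simulate human labels in this sketch.
    """
    rng = np.random.default_rng(seed)
    n, d = Z.shape
    V = lam * np.eye(d)                           # regularized design matrix
    picked, labels = [], []
    for _ in range(budget):
        i, _ = most_uncertain(V, Z)
        y = int(rng.random() < btl_prob(theta_star, Z[i]))  # simulated preference
        picked.append(i)
        labels.append(y)
        V += np.outer(Z[i], Z[i])                 # queried direction becomes less uncertain
    return picked, labels, V
```

Each query shrinks the uncertainty along the queried direction, so the greedy rule spreads the budget across directions that remain poorly covered, whereas uniform sampling can keep drawing contexts that are already well explored.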
