Poster
InfAlign: Inference-aware language model alignment
Ananth Balashankar · Ziteng Sun · Jonathan Berant · Jacob Eisenstein · Michael Collins · Adrian Hutter · Jong Lee · Chirag Nagpal · Flavien Prost · Aradhana Sinha · Ananda Suresh · Ahmad Beirami
East Exhibition Hall A-B #E-2804
Abstract:
Language model alignment is a critical step in training modern generative language models. Alignment aims to improve the win rate of a sample from the aligned model against the base model. Today, we are increasingly using inference-time algorithms (e.g., Best-of-$N$, controlled decoding, tree search) to decode from language models rather than standard sampling. We show that this train/test mismatch makes the standard RLHF framework sub-optimal in view of such inference-time methods. To this end, we propose a framework for inference-aware alignment (InfAlign), which aims to optimize the *inference-time win rate* of the aligned policy against the base model. We prove that for any inference-time decoding procedure, the optimal aligned policy is the solution to the standard RLHF problem with a *transformation* of the reward. This motivates us to provide the calibrate-and-transform RL (InfAlign-CTRL) algorithm to solve this problem, which involves a reward calibration step and a KL-regularized reward maximization step with a transformation of the calibrated reward. For best-of-$N$ sampling and best-of-$N$ jailbreaking, we propose specific transformations offering up to 3-8% improvement on inference-time win rates. Finally, we also show that our proposed reward calibration method is a strong baseline for optimizing standard win rate.
Lay Summary:
Language model alignment is a critical step in training modern generative language models. Alignment aims to improve the win rate of a sample from the aligned model against the base model. Today, we are increasingly using inference-time algorithms (e.g., Best-of-$N$, controlled decoding, tree search) to decode from language models rather than standard sampling. We show that this train/test mismatch makes the standard RLHF framework sub-optimal in view of such inference-time methods. To this end, we propose a framework for inference-aware alignment (InfAlign), which aims to optimize the *inference-time win rate* of the aligned policy against the base model. We prove that for any inference-time decoding procedure, the optimal aligned policy is the solution to the standard RLHF problem with a *transformation* of the reward. This motivates us to provide the calibrate-and-transform RL (InfAlign-CTRL) algorithm to solve this problem, which involves a reward calibration step and a KL-regularized reward maximization step with a transformation of the calibrated reward. For best-of-$N$ sampling and best-of-$N$ jailbreaking, we propose specific transformations offering up to 3-8% improvement on inference-time win rates. Finally, we also show that our proposed reward calibration method is a strong baseline for optimizing standard win rate.
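The calibrate-and-transform idea and the inference-time win rate can be illustrated with a small sketch. The abstract does not spell out how calibration or the transformation is defined, so everything concrete here is an assumption: `calibrate` maps a raw reward to its empirical quantile among rewards of base-policy samples, `exp_transform` is a hypothetical placeholder transformation with an arbitrary parameter `t`, and responses are represented only by their scalar rewards. The KL-regularized RL step is omitted.

```python
import math
import random
from bisect import bisect_right


def calibrate(reward, base_rewards_sorted):
    """Assumed calibration: empirical CDF of `reward` among base-policy rewards.

    Returns a value in [0, 1]; `base_rewards_sorted` must be sorted ascending.
    """
    return bisect_right(base_rewards_sorted, reward) / len(base_rewards_sorted)


def exp_transform(calibrated, t=4.0):
    """Hypothetical transformation of the calibrated reward (placeholder choice)."""
    return math.exp(t * calibrated)


def best_of_n(candidates, reward_fn):
    """Best-of-N decoding: return the candidate with the highest reward."""
    return max(candidates, key=reward_fn)


def win_rate(rewards_a, rewards_b):
    """Fraction of (a, b) pairs where a beats b; ties count as half a win."""
    wins = 0.0
    for a in rewards_a:
        for b in rewards_b:
            if a > b:
                wins += 1.0
            elif a == b:
                wins += 0.5
    return wins / (len(rewards_a) * len(rewards_b))


if __name__ == "__main__":
    rng = random.Random(0)
    # Stand-in for rewards of samples drawn from the base policy on one prompt.
    base = sorted(rng.gauss(0.0, 1.0) for _ in range(1000))

    # Reward that the RL step would maximize under this sketch's assumptions.
    shaped = exp_transform(calibrate(0.5, base))
    print("shaped reward for raw reward 0.5:", shaped)

    # Inference-time win rate of Best-of-4 over base samples vs. single base samples.
    bon = [best_of_n([rng.gauss(0.0, 1.0) for _ in range(4)], lambda r: r)
           for _ in range(200)]
    single = [rng.gauss(0.0, 1.0) for _ in range(200)]
    print("Best-of-4 vs. single-sample win rate:", win_rate(bon, single))
```

Because the calibrated reward is a rank rather than a raw score, it is invariant to monotone rescaling of the reward model, which is one reason a quantile-style calibration is a natural assumption for this sketch.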