Poster
Avoiding spurious sharpness minimization broadens applicability of SAM
Sidak Pal Singh · Hossein Mobahi · Atish Agarwala · Yann Nicolas Dauphin
East Exhibition Hall A-B #E-3504
Curvature regularization techniques like Sharpness Aware Minimization (SAM) have shown great promise in improving generalization on vision tasks. However, we find that SAM performs poorly in domains like natural language processing (NLP), often degrading performance even with twice the compute budget. We investigate the discrepancy across domains and find that in the NLP setting, SAM is dominated by regularization of the logit statistics rather than by improving the geometry of the function itself. We use this observation to develop an alternative algorithm we call Functional SAM, which regularizes curvature only through modification of the statistics of the overall function implemented by the neural network, and avoids spurious minimization through logit manipulation. Furthermore, we argue that preconditioning the SAM perturbation also prevents spurious minimization, and when combined with Functional SAM, it gives further improvements. Our proposed algorithms show improved performance over AdamW and SAM baselines when trained for an equal number of steps, in both fixed-length and Chinchilla-style training settings, at various model scales (including billion-parameter scale). On the whole, our work highlights the importance of more precise characterizations of sharpness in broadening the applicability of curvature regularization to large language models (LLMs).
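To make the contrast concrete, below is a minimal JAX sketch, not the authors' implementation, of standard SAM alongside one plausible reading of a "functional" variant in which the perturbation enters only through the network function (the logits), while the loss-to-logit gradient is held at the unperturbed parameters. The toy model, loss, and hyperparameters (`rho`) are illustrative assumptions.

```python
# Hedged sketch: standard SAM vs. a functional variant of the perturbed gradient.
import jax
import jax.numpy as jnp

def logits_fn(params, x):
    # Toy linear "network": logits = x @ W + b (stands in for an LLM's output head).
    return x @ params["W"] + params["b"]

def loss_from_logits(logits, y):
    # Cross-entropy over a softmax; y holds integer class labels.
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(logp, y[:, None], axis=1))

def loss_fn(params, x, y):
    return loss_from_logits(logits_fn(params, x), y)

def _perturb(params, x, y, rho):
    # Ascent step of size rho along the normalized gradient, as in standard SAM.
    g = jax.grad(loss_fn)(params, x, y)
    norm = jnp.sqrt(sum(jnp.vdot(v, v) for v in jax.tree_util.tree_leaves(g)))
    return jax.tree_util.tree_map(lambda p, gi: p + rho * gi / (norm + 1e-12), params, g)

def sam_grad(params, x, y, rho=0.05):
    # Standard SAM: full gradient of the loss at the adversarially perturbed parameters.
    perturbed = _perturb(params, x, y, rho)
    return jax.grad(loss_fn)(perturbed, x, y)

def functional_sam_grad(params, x, y, rho=0.05):
    # Functional variant (illustrative): the loss-to-logit gradient is evaluated at the
    # *unperturbed* logits, so the update cannot be dominated by regularization of the
    # logit statistics; the perturbation acts only through the network function.
    perturbed = _perturb(params, x, y, rho)
    clean_logit_grad = jax.grad(loss_from_logits)(logits_fn(params, x), y)  # held fixed
    # Pull the fixed logit-space gradient back through the Jacobian at the perturbed point.
    _, vjp = jax.vjp(lambda p: logits_fn(p, x), perturbed)
    return vjp(clean_logit_grad)[0]
```

Either function returns a gradient pytree that can be handed to an optimizer such as AdamW in place of the plain gradient; the functional variant costs the same two gradient evaluations per step as SAM.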
AI models can be "brittle": small, insignificant changes to an input can cause them to make big mistakes. A popular technique called Sharpness Aware Minimization (SAM) helps fix this for vision-based AI by training models to find more stable and robust solutions, similar to a student who learns a concept deeply rather than just memorizing facts.

We discovered that this technique backfires for large language models (LLMs). Instead of improving the model's understanding of language, SAM gets distracted by superficial details of the model's internal predictions.

We developed a new approach, Functional SAM, that corrects this. Our method guides the model to focus on the overall function and meaning of its outputs (or answers), ignoring the internal distractions. This leads to more accurate and reliable language models at all scales, achieving better performance without requiring extra computation.