Poster
In-Context Linear Regression Demystified: Training Dynamics and Mechanistic Interpretability of Multi-Head Softmax Attention
Jianliang He · Xintian Pan · Siyu Chen · Zhuoran Yang
East Exhibition Hall A-B #E-3208
We study how multi-head softmax attention models are trained to perform in-context learning on linear data. Through extensive empirical experiments and rigorous theoretical analysis, we demystify the emergence of elegant attention patterns: a diagonal and homogeneous pattern in the key-query weights, and a last-entry-only and zero-sum pattern in the output-value weights. Remarkably, these patterns consistently emerge from gradient-based training starting from random initialization. Our analysis reveals that such emergent structures enable multi-head attention to approximately implement a debiased gradient descent predictor, which outperforms single-head attention and nearly achieves Bayesian optimality up to a proportional factor. We also extend our study to scenarios with anisotropic covariates and multi-task linear regression. Our results show that in-context learning ability emerges from the trained transformer as an aggregated effect of its architecture and the underlying data distribution, paving the way for deeper understanding and broader applications of in-context learning.
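To make the mechanism concrete, here is a minimal sketch (our own illustration, not the authors' code) of a two-head softmax attention layer with the patterns described above: key-query weights omega * I and -omega * I (diagonal, homogeneous) and output-value weights that read only the label entry with weights +mu and -mu (last-entry-only, zero-sum). It checks numerically that, for a small omega, the layer's output matches a centered one-step gradient descent predictor with effective step size eta = 2 * mu * omega. The two-head setup, the names `two_head_attention` and `debiased_gd_predictor`, the scores s_i = <x_i, x_q> over the n context examples only, and our reading of "debiased" as subtracting the mean score are all assumptions; the paper's exact parameterization may differ.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(z - z.max())
    return e / e.sum()


def two_head_attention(X, y, q, omega, mu):
    """Hypothetical 2-head softmax attention over the n context pairs (x_i, y_i).

    Key-query: omega * I and -omega * I (diagonal, homogeneous), so head scores
    are +/- omega * <x_i, q>. Output-value: only the label entry y_i is read,
    with weights +mu and -mu (zero-sum across the two heads).
    """
    scores = X @ q  # s_i = <x_i, x_q>
    head_pos = softmax(omega * scores) @ y
    head_neg = softmax(-omega * scores) @ y
    return mu * head_pos - mu * head_neg


def debiased_gd_predictor(X, y, q, eta):
    """One GD step from w = 0 on the squared loss, with the mean score subtracted.

    Plain one-step GD predicts (eta / n) * sum_i s_i * y_i; centering the
    scores removes the term eta * mean(s) * mean(y).
    """
    scores = X @ q
    n = len(y)
    return eta / n * np.sum((scores - scores.mean()) * y)


# A synthetic in-context linear regression prompt (isotropic Gaussian covariates).
rng = np.random.default_rng(0)
d, n = 5, 50
beta = rng.standard_normal(d)
X = rng.standard_normal((n, d))
y = X @ beta + 0.1 * rng.standard_normal(n)
q = rng.standard_normal(d)

omega = 1e-3
mu = 1.0 / (2 * omega)  # chosen so the effective step size is eta = 2 * mu * omega = 1
attn = two_head_attention(X, y, q, omega, mu)
gd = debiased_gd_predictor(X, y, q, eta=2 * mu * omega)

# A single head keeps a large mu * mean(y) offset (plus half the GD term);
# the opposite-sign second head is what cancels it.
single = mu * (softmax(omega * (X @ q)) @ y)

print("two-head attention:   ", attn)
print("debiased GD predictor:", gd)
print("single head:          ", single, " mu * mean(y):", mu * y.mean())
```

Why the two heads are needed: expanding softmax(omega * s)_i around omega = 0 gives (1/n)(1 + omega * (s_i - mean(s))) plus higher-order terms. The even-order terms, including the constant that produces the mu * mean(y) bias of a single head, cancel in the difference of the +omega and -omega heads. What remains is (2 * omega / n)(s_i - mean(s)) up to O(omega^3), which is exactly a centered gradient descent step.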
Many AI models can perform in-context learning, but how complex transformer networks actually do this has remained a mystery. We tackled the problem by training a standard multi-head softmax transformer on a simple linear regression task and studying both its learned parameters and its training process. We find that the attention heads learn a neat diagonal pattern in one set of weights (key-query) and a zero-sum pattern in another (output-value), and that these patterns emerge reliably from random starts. These structures let the transformer perform a one-step, debiased form of gradient descent, so it predicts almost as well as the best possible (Bayesian) estimator and beats single-head versions. Our work clarifies how architectural choices and data shape in-context learning, paving the way for more robust, adaptable AI.