Poster
On the Out-of-Distribution Generalization of Self-Supervised Learning
Wenwen Qiang · Jingyao Wang · Zeen Song · Jiangmeng Li · Changwen Zheng
East Exhibition Hall A-B #E-2212
In this paper, we focus on the out-of-distribution (OOD) generalization of self-supervised learning (SSL). By analyzing the mini-batch construction during the SSL training phase, we first offer a plausible explanation for why SSL exhibits OOD generalization. Then, from the perspective of data generation and causal inference, we analyze the training process and conclude that SSL learns spurious correlations, which reduce OOD generalization. To address this issue, we propose a post-intervention distribution (PID) grounded in the Structural Causal Model. PID describes a scenario in which the spurious variable and the label variable are independent. Moreover, we show that if each mini-batch during SSL training satisfies PID, the resulting SSL model achieves optimal worst-case OOD performance. This motivates us to develop a batch sampling strategy that enforces PID constraints by learning a latent variable model. Through theoretical analysis, we establish the identifiability of the latent variable model and validate the effectiveness of the proposed sampling strategy. Experiments on various downstream OOD tasks further demonstrate its effectiveness.
In machine learning, models often struggle when faced with data different from their training examples, a challenge known as out-of-distribution (OOD) generalization. We explored how self-supervised learning (SSL), a method in which models learn from unlabeled data, handles this challenge. We first investigated why SSL models sometimes perform well on OOD tasks and found that the way training examples are grouped (or batched) may explain this ability. However, we also identified a key issue: SSL can inadvertently learn irrelevant relationships (called spurious correlations) from the training data, making models less reliable on new, unseen examples. To solve this, we introduced a technique called the post-intervention distribution (PID), based on causal modeling, which ensures that training batches do not encode these misleading correlations. We then created a practical method for selecting training batches that satisfy the PID condition. Our theoretical and experimental results confirm that this method significantly improves SSL's OOD performance.
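The batching idea above can be illustrated with a minimal sketch. The paper's actual strategy infers the spurious variable with a learned latent variable model; here, for illustration only, we assume each sample already comes with a (hypothetical) spurious attribute and label, and we draw equally from every (label, spurious) cell so that, within each batch, the empirical joint distribution factorizes and the two variables are independent:

```python
import random
from collections import defaultdict

def pid_batches(samples, batch_size, seed=0):
    """Build mini-batches whose empirical joint over (label, spurious)
    is uniform over cells, so label and spurious variable are
    independent within each batch (a stand-in for the PID constraint).

    `samples` is a list of (index, label, spurious) triples; in the
    paper the spurious value would come from an inferred latent
    variable model, not be given directly.
    """
    rng = random.Random(seed)
    # Group sample indices by their (label, spurious) cell.
    buckets = defaultdict(list)
    for idx, y, s in samples:
        buckets[(y, s)].append(idx)
    for bucket in buckets.values():
        rng.shuffle(bucket)
    cells = list(buckets)
    # Draw the same number of samples from every cell per batch.
    per_cell = max(1, batch_size // len(cells))
    batches = []
    while all(len(buckets[c]) >= per_cell for c in cells):
        batch = []
        for c in cells:
            batch.extend(buckets[c][:per_cell])
            del buckets[c][:per_cell]
        batches.append(batch)
    return batches
```

Because every cell contributes the same count, each batch's empirical joint is the product of uniform marginals, which is exactly the independence the PID condition asks for; the real method must additionally handle unbalanced cells and an estimated, rather than observed, spurious variable.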