Labeled data is expensive. Unlabeled data is everywhere. Self‑supervised learning (SSL) turns the second into the first: it creates its own labels from the data itself. This guide explains SSL in plain English first, then dives into the math (entropy, cross‑entropy, KL, InfoNCE), the major families (masked modeling, contrastive, non‑contrastive), and practical systems like BERT, MAE, SimCLR, CLIP, and wav2vec 2.0. You’ll see minimal NumPy and PyTorch code, diagram placeholders, and a high‑quality video for intuition.
Idea: In supervised learning, humans provide labels (cat vs. dog). In self‑supervised learning, the data provides its own supervision. We hide some part of it and train a model to predict the missing part. Over time, the model must understand structure to do well.
Analogy: Imagine covering random words in a sentence with sticky notes and asking a student to fill them in. To succeed, they must understand vocabulary, grammar, and context. That’s language modeling, the classic SSL pretext task.
SSL connection: When we mask words (or image patches) and predict them, we’re reducing uncertainty (entropy) about the missing part given the visible context. We train models to minimize cross‑entropy between the true distribution and the model’s predictions.
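To make that concrete, here is a tiny NumPy sketch (with made-up numbers) of how entropy, cross-entropy, and KL divergence relate for a single masked position:
import numpy as np

# Toy distributions over three candidate words for one blank (numbers are illustrative).
p_true = np.array([0.7, 0.2, 0.1])     # "true" distribution given the visible context
q_model = np.array([0.5, 0.3, 0.2])    # model's current prediction

entropy = -np.sum(p_true * np.log(p_true))          # irreducible uncertainty H(p)
cross_entropy = -np.sum(p_true * np.log(q_model))   # what training minimizes, H(p, q)
kl = cross_entropy - entropy                        # extra cost of using q instead of p

print(f"H(p)={entropy:.3f}  H(p,q)={cross_entropy:.3f}  KL(p||q)={kl:.3f}")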
BERT hides ~15% of tokens and asks the model to fill them in. The loss is cross‑entropy over the masked positions:
\( \mathcal{L}_{\mathrm{MLM}} = -\sum_{t \in \mathcal{M}} \log p_\theta(x_t \mid \mathbf{x}_{\setminus \mathcal{M}}) \)

import numpy as np
# Tiny toy: predict a missing word from its left neighbor via bigram counts
corpus = "the cat sat on the mat the cat ate".split()
vocab = sorted(set(corpus))
ix = {w:i for i,w in enumerate(vocab)}
# Build bigram counts
C = np.zeros((len(vocab), len(vocab)), dtype=np.float32)
for a, b in zip(corpus[:-1], corpus[1:]):
    C[ix[a], ix[b]] += 1
# Turn counts into conditional probabilities p(next | prev)
row_sums = C.sum(1, keepdims=True) + 1e-8
P = C / row_sums
def predict_missing(left_word):
    if left_word not in ix:
        return None
    return vocab[P[ix[left_word]].argmax()]
print("Vocab:", vocab)
print("Most likely after 'the' ->", predict_missing("the"))
This is deliberately simple, but it shows the intuition: to fill in a blank, learn dependencies from context.
Goal: Map augmented views of the same sample to nearby vectors (positives) and views of different samples to distant vectors (negatives). The InfoNCE loss for a query \(q\) and positive key \(k^+\), scored against a set of keys \(\{k_j\}\) that contains \(k^+\) and the negatives:
\( \mathcal{L}_{\mathrm{InfoNCE}} = -\log \frac{\exp(\mathrm{sim}(q,k^+)/\tau)}{\sum_{j} \exp(\mathrm{sim}(q,k_j)/\tau)} \)

Here \(\mathrm{sim}\) is usually cosine similarity and \(\tau\) is a temperature that controls how sharply the softmax concentrates on the most similar keys.
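For intuition, here is a minimal PyTorch sketch of the formula for a single query with one positive and a handful of negatives (the embeddings are random placeholders, not outputs of a real encoder):
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = F.normalize(torch.randn(128), dim=0)                   # query embedding
k_pos = F.normalize(q + 0.1 * torch.randn(128), dim=0)     # positive: a slightly perturbed view
k_neg = F.normalize(torch.randn(8, 128), dim=1)            # 8 negative keys

tau = 0.2
logits = torch.cat([(q @ k_pos).unsqueeze(0), k_neg @ q]) / tau   # positive first, then negatives
loss = -F.log_softmax(logits, dim=0)[0]                    # -log p(positive among all keys)
print(loss.item())   # small, because the positive is far more similar than the negatives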
BYOL uses two networks: an online network (with a predictor head) and a target network (an exponential moving average copy of the online weights). It pulls the online network's predictions toward the target network's projections of another augmented view. Surprisingly, it avoids trivial collapse (constant output) without explicit negatives, thanks to the asymmetric predictor and the stop-gradient on the target branch.
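A minimal sketch of a BYOL-style update, assuming `online`, `predictor`, and `target` modules you build yourself (the names and momentum value are illustrative, not the official implementation):
import copy
import torch
import torch.nn.functional as F

# Hypothetical setup (builders are assumed, not real APIs):
#   online = encoder + projection head; the predictor sits only on the online branch
#   target = copy.deepcopy(online); for p in target.parameters(): p.requires_grad = False

def byol_loss(p_online, z_target):
    # Negative cosine similarity between online prediction and target projection.
    p = F.normalize(p_online, dim=1)
    z = F.normalize(z_target.detach(), dim=1)   # stop-gradient on the target branch
    return 2 - 2 * (p * z).sum(dim=1).mean()

@torch.no_grad()
def ema_update(target, online, m=0.996):
    # Target weights trail the online weights: theta_t <- m*theta_t + (1-m)*theta_o
    for pt, po in zip(target.parameters(), online.parameters()):
        pt.data.mul_(m).add_(po.data, alpha=1 - m)

# One training step (sketch):
#   p1, p2 = predictor(online(x1)), predictor(online(x2))
#   z1, z2 = target(x1), target(x2)
#   loss = byol_loss(p1, z2) + byol_loss(p2, z1)
#   loss.backward(); optimizer.step(); ema_update(target, online)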
CLIP: Train an image encoder and a text encoder so that matching image–caption pairs have high similarity; non‑matching pairs are low. This aligns modalities in a shared space (zero‑shot classification emerges!).
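A sketch of the symmetric contrastive objective behind this idea, assuming you already have batched image and text embeddings (an illustration of the loss, not CLIP's actual code; CLIP also learns the temperature):
import torch
import torch.nn.functional as F

def clip_style_loss(img_emb, txt_emb, tau=0.07):
    # img_emb, txt_emb: [B, D] embeddings where row i of each is a matching image-caption pair.
    img = F.normalize(img_emb, dim=1)
    txt = F.normalize(txt_emb, dim=1)
    logits = img @ txt.t() / tau                              # [B, B] cosine similarities
    targets = torch.arange(img.size(0), device=img.device)    # matches lie on the diagonal
    loss_i2t = F.cross_entropy(logits, targets)               # image -> text direction
    loss_t2i = F.cross_entropy(logits.t(), targets)           # text -> image direction
    return (loss_i2t + loss_t2i) / 2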
wav2vec 2.0: SSL for speech. Mask latent audio chunks; context network predicts them; features transfer to ASR with few labels.
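A rough sketch of the span-masking step on latent frames (shapes, probabilities, and names are illustrative; the real model uses a learned mask embedding and a contrastive loss against quantized targets):
import torch

def mask_spans(latent, mask_prob=0.065, span=10):
    # latent: [B, T, D] frames from a convolutional feature encoder (assumed given).
    B, T, _ = latent.shape
    mask = torch.zeros(B, T, dtype=torch.bool)
    starts = torch.rand(B, T) < mask_prob          # sample span start positions
    for b, t in starts.nonzero(as_tuple=False).tolist():
        mask[b, t:t + span] = True
    masked = latent.clone()
    masked[mask] = 0.0                             # stand-in for the learned mask embedding
    return masked, mask

# A Transformer "context network" then predicts the true latent at each masked position,
# scored contrastively (InfoNCE-style) against distractor frames from the same utterance.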
# pip install transformers datasets torch
from transformers import BertTokenizerFast, BertForMaskedLM, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
# 1) Tokenizer & model
tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# 2) Data: use a tiny public text dataset for demo (e.g., wikitext-2-raw-v1)
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
def tok_fn(ex):
    return tok(ex["text"], truncation=True, padding="max_length", max_length=128)
tok_ds = ds.map(tok_fn, batched=True, remove_columns=ds["train"].column_names)
# 3) Data collator will randomly mask tokens as BERT's pretext task
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm_probability=0.15)
# 4) Trainer
args = TrainingArguments(
    output_dir="./mlm-demo",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="steps",
    logging_steps=200,
    save_steps=200,
    max_steps=1000,  # small demo
)
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=tok_ds["train"].select(range(5000)),      # small subset
    eval_dataset=tok_ds["validation"].select(range(1000)),
)
trainer.train()
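After the short demo run (or even with the pretrained checkpoint alone), a quick way to sanity-check the objective is the fill-mask pipeline:
from transformers import pipeline

# The model should fill the blank with plausible tokens and their probabilities.
fill = pipeline("fill-mask", model=model, tokenizer=tok)
for pred in fill("The cat sat on the [MASK]."):
    print(pred["token_str"], round(pred["score"], 3))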
import torch, torch.nn as nn, torch.nn.functional as F
class Projection(nn.Module):
    def __init__(self, dim, proj=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(inplace=True),
            nn.Linear(dim, proj))

    def forward(self, x):
        return self.net(x)
def nt_xent(z1, z2, tau=0.2):
    # z1, z2: [B, D], already L2-normalized
    B = z1.size(0)
    z = torch.cat([z1, z2], dim=0)        # [2B, D]
    sim = torch.matmul(z, z.t()) / tau    # [2B, 2B]
    mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    sim.masked_fill_(mask, -9e15)         # remove self-similarity
    # positives: row i pairs with row i+B, and row i+B pairs with row i
    labels = torch.cat([torch.arange(B, 2 * B), torch.arange(0, B)]).to(z.device)
    loss = F.cross_entropy(sim, labels)
    return loss
# encoder = YourBackbone()
# proj = Projection(dim=encoder_out)
# for (x1, x2) in loader:                      # two augmented views per image
#     h1, h2 = encoder(x1), encoder(x2)
#     z1, z2 = F.normalize(proj(h1), dim=1), F.normalize(proj(h2), dim=1)
#     loss = nt_xent(z1, z2)
Because SSL scales on unlabeled data, curation is crucial: bias, privacy, license compliance, and content filtering all matter. Representations inherit the statistics of what they see.