Activations and Gradients#
Introduction#
Training neural networks involves computation across millions (or billions) of weights and activations. This process can be fragile. In this notebook, we attach hooks to deep neural nets in order to analyze the statistics of activations and gradients during training, and consider pitfalls when they are improperly scaled. Finally, we introduce layer normalization (LN) [BKH16], which allows stable propagation of activations and gradients across layers. This makes training deep networks much easier (e.g. without a lot of hyperparameter tuning).
In the appendix, we consider the effects of the choice of activation function on training dynamics. Finally, we discuss rank collapse for neural network layers where the dimensionality of the output space of each layer degenerates with depth. We find empirically that batch normalization [IS15] prevents rank collapse in MLPs, while LN does not. This is consistent with [DKB+20] and [DCL21].
Preliminaries#
import math
import torch
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib_inline import backend_inline
# Directory holding the CSV files used throughout the notebook.
DATASET_DIR = Path("./data").absolute()
# Single seed shared by every experiment (see set_seed).
RANDOM_SEED = 1
# Debug mode renders figures as PNG instead of SVG.
DEBUG = False
if DEBUG:
    MATPLOTLIB_FORMAT = "png"
else:
    MATPLOTLIB_FORMAT = "svg"
def set_seed(s: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same value ``s``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)
# Session-wide setup: silence warnings, seed every RNG, render vector figures.
# NOTE(review): this ignores *all* warnings (including deprecations) for the session.
warnings.simplefilter("ignore")
set_seed(RANDOM_SEED)
backend_inline.set_matplotlib_formats(MATPLOTLIB_FORMAT)
Still using the names dataset from the previous notebook:
import os
# Download the two raw surname tables into DATASET_DIR (only the missing ones).
# Each file is checked separately so a partially completed download is
# resumed, and the directory is created first because the download would
# otherwise fail when ./data does not exist yet. Plain Python is used instead
# of IPython's `!wget` so the cell also runs outside a notebook.
import urllib.request

DATA_URL = "https://raw.githubusercontent.com/particle1331/spanish-names-surnames/master"
DATA_FILES = ["surnames_freq_ge_100.csv", "surnames_freq_ge_20_le_99.csv"]

DATASET_DIR.mkdir(parents=True, exist_ok=True)
missing = [fname for fname in DATA_FILES if not (DATASET_DIR / fname).is_file()]
if missing:
    for fname in missing:
        urllib.request.urlretrieve(f"{DATA_URL}/{fname}", DATASET_DIR / fname)
else:
    print("Data files already exist.")

col = ["surname", "frequency_first", "frequency_second", "frequency_both"]
df1 = pd.read_csv(DATASET_DIR / "surnames_freq_ge_100.csv", names=col, header=0)
df2 = pd.read_csv(DATASET_DIR / "surnames_freq_ge_20_le_99.csv", names=col, header=0)
Show code cell output
Data files already exist.
# Merge both frequency tables, optionally subsample (DEBUG), and normalize the
# surnames: lowercase, fuse "de la" into one token, "_" as word separator.
FRAC_LIMIT = 0.10 if DEBUG else 1.0
df = pd.concat([df1, df2], axis=0)[["surname"]].sample(frac=FRAC_LIMIT)
# Drop missing values *before* the string operations: calling .lower() on a
# NaN (a float) would raise.
df = df.dropna().astype(str)
df["surname"] = (
    df["surname"]
    .str.lower()
    .str.replace("de la", "dela", regex=False)
    .str.replace(" ", "_", regex=False)
)
# Discard names containing an apostrophe or "ç", and names shorter than 2 chars.
names = [n for n in df.surname.tolist() if "'" not in n and "ç" not in n and len(n) >= 2]
df = df[df.surname.isin(names)]
df.to_csv(DATASET_DIR / "spanish_surnames.csv", index=False)
# NOTE(review): read_csv parses strings such as "nan" or "null" as NaN by
# default, so such surnames are dropped from `df` here but remain in `names`.
df = pd.read_csv(DATASET_DIR / "spanish_surnames.csv").dropna()
df.head()
surname | |
---|---|
0 | agredano |
1 | poblador |
2 | girba |
3 | rabanales |
4 | yucra |
Datasets#
Defining here the dataset class used in the previous notebook:
import torch
from torch.utils.data import Dataset
class CharDataset(Dataset):
    """Character-level next-token dataset.

    Item ``idx`` is a pair ``(x, y)``: ``x`` is a LongTensor with the indices
    of the context characters ``contexts[idx]`` and ``y`` is a scalar
    LongTensor with the index of the target character ``targets[idx]``.

    Args:
        contexts: context windows, each a sequence of single characters. The
            length of the first one is stored as ``block_size``.
        targets: the character following each context, aligned with ``contexts``.
        chars: the vocabulary; a character's position in it is its index.
    """

    def __init__(self, contexts: list[str], targets: list[str], chars: str):
        self.chars = chars
        self.ys = targets
        self.xs = contexts
        # Context length; 0 for an empty dataset instead of raising IndexError.
        self.block_size = len(contexts[0]) if len(contexts) > 0 else 0
        self.itos = dict(enumerate(self.chars))
        self.stoi = {c: i for i, c in self.itos.items()}

    def get_vocab_size(self) -> int:
        return len(self.chars)

    def __len__(self) -> int:
        return len(self.xs)

    def __getitem__(self, idx):
        x = self.encode(self.xs[idx])
        y = torch.tensor(self.stoi[self.ys[idx]], dtype=torch.long)
        return x, y

    def decode(self, x: torch.Tensor) -> str:
        """Map a 1-D tensor of character indices back to a string."""
        return "".join(self.itos[i] for i in x.tolist())

    def encode(self, word: str) -> torch.Tensor:
        """Map a sequence of characters to a LongTensor of indices."""
        return torch.tensor([self.stoi[c] for c in word], dtype=torch.long)
def build_dataset(names, block_size=3, chars=None):
    """Build a CharDataset of (context -> next char) pairs from ``names``.

    Every name starts with a context of ``block_size`` "." characters and is
    terminated with "."; each character (including the final ".") becomes a
    target whose context is the ``block_size`` characters preceding it.

    Args:
        names: iterable of strings.
        block_size: number of context characters per example.
        chars: optional vocabulary. Defaults to the sorted set of target
            characters found in ``names``. Pass the vocabulary of the full
            dataset when building splits so every split uses the same
            character -> index mapping.
    """
    xs = []  # context list
    ys = []  # target list
    for name in names:
        context = ["."] * block_size
        for c in name + ".":
            xs.append(context)
            ys.append(c)
            context = context[1:] + [c]
    if chars is None:
        chars = sorted(set("".join(ys)))
    return CharDataset(contexts=xs, targets=ys, chars=chars)
Example dataset with block size 3:
# Decode the first few (context, target) pairs to sanity-check the encoding.
dataset = build_dataset(names, block_size=3)
xs, ys = map(list, zip(*(dataset[i] for i in range(7))))
pd.DataFrame({
    "x": [ctx.tolist() for ctx in xs],
    "y": [tgt.item() for tgt in ys],
    "x_word": [dataset.decode(ctx) for ctx in xs],
    "y_char": [dataset.itos[tgt.item()] for tgt in ys],
})
x | y | x_word | y_char | |
---|---|---|---|---|
0 | [0, 0, 0] | 2 | ... | a |
1 | [0, 0, 2] | 8 | ..a | g |
2 | [0, 2, 8] | 19 | .ag | r |
3 | [2, 8, 19] | 6 | agr | e |
4 | [8, 19, 6] | 5 | gre | d |
5 | [19, 6, 5] | 2 | red | a |
6 | [6, 5, 2] | 15 | eda | n |
Ideally, we should use a stratified k-fold that ensures the character distribution is the same between splits, but for simplicity we skip this. Here we just partition the (already shuffled) names dataset by index to create the train and validation datasets.
# Fraction of the (already shuffled) names held out for validation; the
# remaining 1 - SPLIT_RATIO is used for training.
SPLIT_RATIO = 0.30
split_point = int((1 - SPLIT_RATIO) * len(names))
names_train = names[:split_point]
names_valid = names[split_point:]
train_dataset = build_dataset(names_train, block_size=3)
valid_dataset = build_dataset(names_valid, block_size=3)
# Re-index both splits with the vocabulary of the full dataset so a character
# maps to the same index in train and valid, even if some character happens
# to be missing from one of the splits.
train_dataset = CharDataset(train_dataset.xs, train_dataset.ys, chars=dataset.chars)
valid_dataset = CharDataset(valid_dataset.xs, valid_dataset.ys, chars=dataset.chars)
len(train_dataset), len(valid_dataset)
(186339, 433815)
Weight initialization#
SGD requires choosing an arbitrary starting point \(\boldsymbol{\Theta}_{\text{init}}.\) Setting all weights to zero or some constant does not work as symmetry in the neurons of the network will make it difficult (if not impossible) to train the model. Hence, setting the weights randomly to break symmetry is a good starting point. However, this is still not enough since the variance of every neuron is additive (again due to symmetry and some assumptions):
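\[\operatorname{Var}(\mathsf{y}_j) = \operatorname{Var}\Big(\sum_{i=1}^{n} \mathsf{x}_i \mathsf{w}_{ij}\Big) = \sum_{i=1}^{n} \operatorname{Var}(\mathsf{x}_i)\operatorname{Var}(\mathsf{w}_{ij}) = n\,\sigma_{\boldsymbol{\mathsf{x}}}^2\,\sigma_{\boldsymbol{\mathsf{w}}}^2\]
where \(n = |\boldsymbol{\mathsf{x}}|\), assuming the \(\mathsf{x}_i\) and \(\mathsf{w}_{ij}\) are independent with zero mean.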
import seaborn as sns
# Unit-variance inputs and weights: each output entry sums n = 100 products.
x = torch.randn(1000, 100)
w = torch.randn(100, 200)
y = torch.matmul(x, w)
Show code cell source
# Input vs. output distribution for unscaled weights: the output spreads out.
# Raw strings: "\m" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
# NOTE(review): sns.distplot is deprecated in seaborn >= 0.11.
plt.figure(figsize=(5, 3))
sns.distplot(x.reshape(-1), color="C0", label=r"$\mathbf{x}$ (input)");
sns.distplot(y.reshape(-1), color="C1", label=r"$\mathbf{y}$ (output)");
plt.legend();
Observe that the output \(\boldsymbol{\mathsf{y}} = \boldsymbol{\mathsf{x}}^\top \boldsymbol{\mathsf{w}}\) for normally distributed \(\boldsymbol{\mathsf{x}}\) and \(\boldsymbol{\mathsf{w}}\) starts to spread out. This makes sense since we are adding \(n\) terms where \(n = |\boldsymbol{\mathsf{x}}|.\) Ideally, we want \(\boldsymbol{\mathsf{y}}\) to have the same standard deviation to maintain stable activation flow. Otherwise, activations will recursively grow at each layer. This becomes increasingly problematic with depth.
Dead neurons. Activations have regions where their gradients saturate, so it is very important to control the range of preactivations. A dead neuron has vanishingly small gradients for most training examples. This means that the weights for this neuron will change very slowly compared to other neurons.
A neuron can be dead at initialization or die in the course of training. For instance, a high learning rate can result in large weights that saturate the neurons similar to the situation above. For dense layers, dead neurons tend to remain dead for a long time since weight gradients are too small to significantly change the existing weights for that neuron.
Show code cell source
# Activation functions compared below, keyed by display name.
activations = {
    "tanh": torch.tanh,
    "relu": torch.relu,
}


def d(f, x):
    """Compute the elementwise derivative f'(x) of activation ``f`` at ``x``.

    ``x`` must be a tensor with ``requires_grad=True``. The gradient is
    computed with ``torch.autograd.grad``, so each call returns a fresh result:
    nothing is accumulated into (or read back from) ``x.grad``.
    """
    y = f(x)
    (grad,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
    return grad
def plot_activation(act: str, ax, x):
    """Plot activation ``activations[act]`` and its derivative on ``ax`` over ``x``.

    Saturation regions, where |f'(x)| <= 0.1, are shaded gray.
    """
    x.grad = None  # discard any stale gradient left on x
    f = activations[act]
    y = f(x)
    dy = d(f, x)
    x_np, y_np, dy_np = (t.detach().cpu().numpy() for t in (x, y, dy))

    ax.plot(x_np, y_np, linewidth=2, color="red", label="f(x)")
    ax.plot(x_np, dy_np, linewidth=2, color="black", label="f'(x)")
    ax.set_title(act)
    ax.grid(linestyle="dashed")
    ax.set_ylim(-1.5, 4)
    # axvspan's ymin/ymax are axes fractions; (-10, 10) spans the full height.
    if act == "tanh":
        # tanh'(x) = 1 - tanh(x)^2 <= 0.1  <=>  |x| >= atanh(sqrt(0.9)) ~ 1.8183
        ax.axvspan(-5, -1.8183, -10, 10, color="lightgray")
        ax.axvspan(1.8183, 5, -10, 10, color="lightgray")
    elif act == "relu":
        ax.axvspan(-5, 0, -10, 10, color="lightgray", label="|f'(x)| ≤ 0.1")
    ax.legend(loc="lower right")


fig, ax = plt.subplots(1, 2, figsize=(10, 3.5))
x = torch.linspace(-5, 5, 1000, requires_grad=True)
for i, act_name in enumerate(activations):
    plot_activation(act_name, ax[i], x)
fig.tight_layout()
Figure. Plotting activations and activation gradients. Saturation regions, where the derivative is small (here \(|f'(x)| \leq 0.1\)), are highlighted in gray.
Xavier init. [GB10] Since \(\sigma_{\boldsymbol{\mathsf{y}}} = \sqrt{n} \cdot \sigma_{\boldsymbol{\mathsf{w}}} \, \sigma_{\boldsymbol{\mathsf{x}}}\), one straightforward fix is to initialize the weights \(\boldsymbol{\mathsf{w}}\) with a distribution having \({\sigma_{\boldsymbol{\mathsf{w}}}}=\frac{1}{\sqrt{n}}\) where \(n = |\boldsymbol{\mathsf{x}}|.\) Moreover, we set biases to zero. Note that setting the standard deviation is equivalent to just scaling a unit-variance random variable by \(\frac{1}{\sqrt{n}}\), since \(\sigma_{c\boldsymbol{\mathsf{w}}} = |c| \, \sigma_{\boldsymbol{\mathsf{w}}}\):
# Xavier normal (fan-in): scale unit-normal weights by 1/sqrt(n) with n = 100.
# For Xavier uniform instead, sample w ~ U[-a, a] with a = (3 / n) ** 0.5.
x = torch.randn(1000, 100)
w = torch.randn(100, 200)
w_scaled = w / np.sqrt(100)
y = x.matmul(w_scaled)
Show code cell source
# With Xavier scaling the output keeps (roughly) the input's spread.
# Raw strings: "\m" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
plt.figure(figsize=(5, 3))
sns.distplot(x.reshape(-1), color="C0", label=r"$\mathbf{x}$ (input)")
sns.distplot(y.reshape(-1), color="C1", label=r"$\mathbf{y}$ (output)")
plt.ylim(0, 0.6)
plt.legend();
Remark. Proper weight init can also help with the variance of gradients. For dense layers, the backprop equations that relate input and output gradients are linear in \(\boldsymbol{\mathsf{w}}^\top.\) Improper scaling of weights can therefore cause vanishing or exploding gradients in MLPs, as recursively multiplying by weight matrices can increase or decrease gradients exponentially with depth. To preserve the variance of gradients in the backward pass, we can instead sample weights with \(\sigma_{\boldsymbol{\mathsf{w}}} = \frac{1} {\sqrt{n_\text{out}}}.\) As a compromise between the two, weight init can be implemented with fan average mode: \(\sigma_{\boldsymbol{\mathsf{w}}} = \sqrt{{2}/{{(n_\text{in} + n_\text{out})}}}.\)
Gain. Note that this scale factor only holds for linear layers. Nonlinear activations squash their inputs, and this effect compounds as we stack layers in deep networks. The factor \(\mathsf{g}\) such that \({\sigma_{\boldsymbol{\mathsf{y}}}} = \mathsf{g} \cdot {\sigma_{\boldsymbol{\mathsf{x}}}}\) is called the gain. This can be introduced as a factor on the standard deviation of the weights, i.e. setting the parameters of the distribution such that \(\sigma_{\boldsymbol{\mathsf{w}}} = \mathsf{g} \frac{1}{\sqrt{n}}\) where \(n = |\boldsymbol{\mathsf{x}}|\) for some \(\mathsf{g} > 0.\) Typically, weights are sampled from either normal or uniform distributions.
Show code cell source
# Left: plain Xavier scaling ignores the squashing by tanh, so outputs shrink.
# Right: multiplying by the tanh gain g = 5/3 restores the spread.
# Raw strings: "\m" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
fig, ax = plt.subplots(1, 2, figsize=(8, 3))

n = 100
x = torch.randn(1000, n)
z = torch.tanh(x)  # Note: input of a hidden layer is an activation
w = torch.randn(n, 200) / n ** 0.5
y = z @ w
sns.distplot(x.reshape(-1), ax=ax[0], color="C0", label=r"$\mathbf{x}$ (input)")
sns.distplot(y.reshape(-1), ax=ax[0], color="C1", label=r"$\mathbf{y}$ (output)")
ax[0].set_title(r"$w \sim N(0, \sigma^2),$ $\sigma = n ^{-0.5}$", size=12)
ax[0].set_ylim(0, 0.8)

g = 5 / 3  # tanh gain
x = torch.randn(1000, n)
z = torch.tanh(x)
w = torch.randn(n, 200) * g / n ** 0.5
y = z @ w
sns.distplot(x.reshape(-1), ax=ax[1], color="C0", label=r"$\mathbf{x}$ (input)")
sns.distplot(y.reshape(-1), ax=ax[1], color="C1", label=r"$\mathbf{y}$ (output)")
ax[1].set_title(r"$w \sim N(0, \sigma^2),$ $\sigma = \frac{5}{3} {n}^{-0.5}$", size=12)
ax[1].set_ylim(0, 0.8)
ax[1].legend()
fig.tight_layout();
Kaiming init. Note that \(\mathsf{g}\) for an activation is usually obtained using some heuristic or by performing empirical tests, e.g. [HZRS15b] sets \(\mathsf{g} = \sqrt{2}\) since half of ReLU outputs are zero in the calculation of variance (see proof). See also calculate_gain
for other commonly used activations.
PyTorch default for nn.Linear
is essentially \(\mathsf{g} = \frac{1}{\sqrt{3}}\) with Xavier uniform fan-in initialization. Note that the standard deviation of the uniform distribution \(U[-a, a]\) is specified by the bound \(a\), i.e. \(\sigma = \frac{a}{\sqrt{3}}.\) The Pytorch implementation sets \(a = \frac{1}{\sqrt{n}}\) getting an effective gain of \(\frac{1}{\sqrt{3}}.\)
import torch.nn as nn
# PyTorch's default nn.Linear init samples U[-1/sqrt(fan_in), 1/sqrt(fan_in)],
# whose standard deviation is 1/sqrt(3 * fan_in).
fan_in = 100
lin = nn.Linear(fan_in, 1234)
w = lin.weight.data
expected_std = 1 / (np.sqrt(3) * np.sqrt(fan_in))
w.std(), expected_std
(tensor(0.0576), 0.05773502691896258)
Remark. The formula for \(\sigma\) of \(U[-a, a]\) also explains the form of nn.init.kaiming_uniform_
.
Logits. Note that the Xavier init heuristic also applies to the weights of the logits layer. It essentially acts as softmax temperature:
import torch.nn.functional as F
# The logits-layer init acts like a softmax temperature: unit-variance weights
# give a peaked distribution, Xavier-scaled weights a softer one.
x = torch.randn(1, 30)
w = torch.randn(30, 10)
# Explicit class dimension; calling softmax without `dim` is deprecated.
y0 = F.softmax(x @ w, dim=-1)
y1 = F.softmax(x @ (w / np.sqrt(30)), dim=-1)
Show code cell source
# Class probabilities side by side for the two weight scales.
fig, ax = plt.subplots(1, 2, figsize=(8, 3))
titles = (r"$\sigma = 1$", r"$\sigma = {1}/{\sqrt{30}}$")
for axis, title, probs in zip(ax, titles, (y0, y1)):
    axis.set_title(title, size=11)
    axis.set_xlabel("class index")
    axis.bar(range(10), probs[0])
Training with hooks#
In our experiments, we use total steps per run instead of epochs. The following class wraps a data loader so that the iterator resets whenever the number of steps exceeds one epoch (i.e. whenever a single pass over the entire training dataset has been completed):
class InfiniteDataLoader:
    """Iterate over ``data_loader`` forever, restarting it after each pass.

    Raises:
        ValueError: from ``next`` if ``data_loader`` yields no batches at all,
            since iteration could otherwise never make progress.
    """

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iterator = iter(self.data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iterator)
        except StopIteration:
            # End of a pass: start a fresh iterator over the loader.
            self.data_iterator = iter(self.data_loader)
        try:
            return next(self.data_iterator)
        except StopIteration:
            raise ValueError("data_loader yielded no batches") from None
The Trainer
class below streamlines model training (see the previous notebook). Note that the class is augmented with hooks which we will use to obtain activation and gradient statistics during forward and backward passes. The hooks are attached to the model before training, then removed after. We skip validation during training since our current concern is training dynamics (not generalization).
from tqdm.notebook import tqdm
from contextlib import contextmanager
from torch.utils.data import DataLoader
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU so the
# notebook also runs on machines without an accelerator.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


@contextmanager
def eval_context(model):
    """Temporarily put ``model`` in eval mode; restore its previous mode on exit."""
    state = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(state)
class Trainer:
    """Minimal step-based training loop with optional module hooks.

    Args:
        model: the ``nn.Module`` to train.
        optim: optimizer over ``model``'s parameters.
        loss_fn: called as ``loss_fn(preds, targets)``; ``valid_step`` also
            passes ``reduction="sum"``.
        scheduler: optional LR scheduler, stepped once per training step.
        callbacks: zero-argument callables invoked after every training step.
        device: device that input batches are moved to.
        verbose: show a progress bar in ``run``.
    """

    def __init__(self,
                 model, optim, loss_fn,
                 scheduler=None,
                 callbacks=None,
                 device=DEVICE, verbose=True):
        self.model = model
        self.optim = optim
        self.device = device
        self.loss_fn = loss_fn
        self.logs = {"train": {"loss": []}}
        self.verbose = verbose
        self.scheduler = scheduler
        # None -> a fresh list per trainer; a mutable default list would be
        # shared by every Trainer created without explicit callbacks.
        self.callbacks = callbacks if callbacks is not None else []
        self.forward_hooks = []   # handles from register_forward_hook
        self.backward_hooks = []  # handles from register_backward_hook

    def __call__(self, x):
        """Run the model on ``x`` after moving it to the trainer's device."""
        return self.model(x.to(self.device))

    def forward(self, batch):
        """Move an ``(x, y)`` batch to the device and return ``(model(x), y)``."""
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        return self.model(x), y

    def train_step(self, batch):
        """One optimizer step on ``batch``; returns the loss tensor."""
        preds, y = self.forward(batch)
        loss = self.loss_fn(preds, y)
        loss.backward()
        self.optim.step()
        self.optim.zero_grad()
        return loss

    @torch.inference_mode()
    def valid_step(self, batch):
        """Summed (not averaged) loss over ``batch``, without autograd tracking."""
        preds, y = self.forward(batch)
        loss = self.loss_fn(preds, y, reduction="sum")
        return loss

    def register_hooks(self,
                       fwd_hooks=(),
                       bwd_hooks=(),
                       layer_types=(nn.Linear, nn.Conv2d)):
        """Attach hooks to every submodule that is an instance of ``layer_types``.

        Forward hooks are called as ``hook(module, input, output)`` and
        backward hooks as ``hook(module, grad_input, grad_output)``. The
        handles are kept so that ``remove_hooks`` can detach them.
        """
        for module in self.model.modules():
            if not isinstance(module, layer_types):
                continue
            for f in fwd_hooks:
                self.forward_hooks.append(module.register_forward_hook(f))
            for g in bwd_hooks:
                # NOTE(review): the legacy (deprecated) register_backward_hook is kept
                # on purpose. WeightGradientStats looks for a weight-shaped tensor in
                # grad_input, whereas register_full_backward_hook only reports
                # gradients w.r.t. the module's inputs.
                self.backward_hooks.append(module.register_backward_hook(g))

    def remove_hooks(self):
        """Detach all registered hooks and forget their handles."""
        for h in self.forward_hooks + self.backward_hooks:
            h.remove()
        self.forward_hooks.clear()
        self.backward_hooks.clear()

    def run(self, steps, train_loader):
        """Train for ``steps`` steps, cycling over ``train_loader`` as needed."""
        train_loader = InfiniteDataLoader(train_loader)
        for _ in tqdm(range(steps), disable=not self.verbose):
            # optim and lr step
            batch = next(train_loader)
            loss = self.train_step(batch)
            if self.scheduler:
                self.scheduler.step()

            # step callbacks
            for callback in self.callbacks:
                callback()

            # logs @ train step
            self.logs["train"]["loss"].append(loss.item())

    def evaluate(self, data_loader):
        """Average per-example loss over ``data_loader``, computed in eval mode."""
        with eval_context(self.model):
            valid_loss = 0.0
            for batch in data_loader:
                loss = self.valid_step(batch)
                valid_loss += loss.item()
        return {"loss": valid_loss / len(data_loader.dataset)}
Forward hooks are expected to look like:
hook(module, input, output) -> None or modified output
On the other hand, backward hook functions should look like:
hook(module, grad_in, grad_out) -> None or modified grad_in
The following classes are for handling outputs of the hook functions:
class HookHandler:
    """Base class: collects per-module records produced by ``hook_fn``."""

    def __init__(self):
        # module -> list of records, one entry appended per hook call
        self.records = {}

    def hook_fn(self, module, input, output):
        raise NotImplementedError


class OutputStats(HookHandler):
    """Forward hook: record each module's output tensor."""

    def hook_fn(self, module, input, output):
        # Detach so the stored tensors do not keep every step's autograd graph alive.
        self.records.setdefault(module, []).append(output.detach())


class WeightGradientStats(HookHandler):
    """Backward hook: record the entries of ``grad_in`` shaped like ``module.weight``."""

    def hook_fn(self, module, grad_in, grad_out):
        records = self.records.setdefault(module, [])
        for t in grad_in:
            if t is not None and t.shape == module.weight.shape:
                records.append(t.detach())


class ActivationGradientStats(HookHandler):
    """Backward hook: record ``grad_in[0]``, the gradient w.r.t. the module's first input."""

    def hook_fn(self, module, grad_in, grad_out):
        records = self.records.setdefault(module, [])
        # Record at most once per call, and only when the gradient exists.
        if grad_in and grad_in[0] is not None:
            records.append(grad_in[0].detach())


class DyingReluStats(HookHandler):
    """Forward hook: fraction of units that are "dead" on the current batch.

    A unit counts as dead when its output is below ``sat_threshold`` for more
    than ``frac_threshold`` of the examples in the batch (dim 0).
    """

    def __init__(self, sat_threshold=1e-8, frac_threshold=0.95):
        super().__init__()
        self.sat_threshold = sat_threshold
        self.frac_threshold = frac_threshold

    def hook_fn(self, module, input, output):
        records = self.records.setdefault(module, [])
        # Per-unit fraction of the batch on which the unit is saturated.
        frac_saturated = (output < self.sat_threshold).float().mean(dim=0)
        is_dead = frac_saturated > self.frac_threshold
        records.append(is_dead.float().mean().item())
Setting up our experiment harness:
def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int):
    """MLP next-character model.

    Embedding -> Flatten -> (num_linear + 1) x [Linear -> ReLU] -> Linear to
    ``vocab_size`` logits. The first Linear maps the flattened context
    (``block_size * emb_size`` features) to ``width``.
    """
    layers = [nn.Embedding(vocab_size, emb_size), nn.Flatten()]
    in_features = block_size * emb_size
    for _ in range(num_linear + 1):
        layers += [nn.Linear(in_features, width), nn.ReLU()]
        in_features = width
    layers.append(nn.Linear(width, vocab_size))
    return nn.Sequential(*layers)
def init_model(model, fn):
    """Re-initialize ``model`` in place and return it.

    Linear layers: ``fn(weight)`` is expected to modify the weight in place
    (see https://pytorch.org/docs/stable/nn.init.html for init functions);
    biases are zeroed. Embeddings: standard normal, drawn from a generator
    freshly seeded with RANDOM_SEED.
    """
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            seeded = torch.Generator().manual_seed(RANDOM_SEED)
            nn.init.normal_(module.weight, generator=seeded)
        if isinstance(module, nn.Linear):
            fn(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    return model
def run(model,
        train_dataset: Dataset,
        batch_size: int,
        hooks: list[dict],
        steps: int,
        lr: float) -> Trainer:
    """Train ``model`` with plain SGD for ``steps`` steps and return its Trainer.

    Args:
        hooks: list of dicts with keys ``"fwd_hooks"`` and ``"bwd_hooks"``
            (lists of HookHandler objects) and ``"layer_types"`` (tuple of
            module classes to attach them to). Hooks are always removed at the end.

    An exception raised during training is reported with its full traceback
    but not re-raised, so the partially trained Trainer (and its logs) is
    still returned.
    """
    import traceback

    # Setup optimization and data loader
    set_seed(RANDOM_SEED)
    loss_fn = F.cross_entropy
    model = model.to(DEVICE)
    optim = torch.optim.SGD(model.parameters(), lr=lr)
    trainer = Trainer(model, optim, loss_fn, device=DEVICE)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, shuffle=True)

    # Finally, training...
    try:
        for hook_dict in hooks:
            trainer.register_hooks(
                fwd_hooks=[h.hook_fn for h in hook_dict["fwd_hooks"]],
                bwd_hooks=[h.hook_fn for h in hook_dict["bwd_hooks"]],
                layer_types=hook_dict["layer_types"]
            )
        trainer.run(steps=steps, train_loader=train_loader)
    except Exception:
        # Best effort: keep the notebook going, but show where training failed
        # rather than only the exception message.
        traceback.print_exc()
    finally:
        trainer.remove_hooks()
    return trainer
The task and model used here are from the previous notebook. In the init_model
function, we iterate over the layers and apply a function from the nn.init
library to modify layer weights in-place. The xavier_uniform_
implementation uses fan average mode. This is applied to linear layer weights, while biases are set to zero. Finally, the embedding layer is initialized with weights from a Gaussian distribution.
def init_hooks() -> list[dict]:
    """Create fresh hook handlers and return the hook specs for ``run``.

    The handlers are also bound to module-level globals so that they can be
    plotted after training. First spec: output and weight-gradient stats on
    Linear layers. Second spec: dead-unit, output and input-gradient stats on
    ReLU layers.
    """
    global dead_relu_stats, relu_outs_stats, relu_grad_stats
    global linear_outs_stats, linear_grad_stats

    linear_outs_stats, linear_grad_stats = OutputStats(), WeightGradientStats()
    dead_relu_stats = DyingReluStats(sat_threshold=1e-8, frac_threshold=0.95)
    relu_outs_stats, relu_grad_stats = OutputStats(), ActivationGradientStats()

    linear_spec = dict(
        fwd_hooks=[linear_outs_stats],
        bwd_hooks=[linear_grad_stats],
        layer_types=(nn.Linear,),
    )
    relu_spec = dict(
        fwd_hooks=[dead_relu_stats, relu_outs_stats],
        bwd_hooks=[relu_grad_stats],
        layer_types=(nn.ReLU,),
    )
    return [linear_spec, relu_spec]
# Baseline: two hidden layers of width 30, plain SGD (lr=0.1) for 1000 steps.
model_params = dict(vocab_size=29, block_size=3, emb_size=10, width=30, num_linear=2)
train_params = dict(steps=1000, batch_size=32, lr=0.1)


def g():
    """Return a fresh torch.Generator seeded with RANDOM_SEED (reproducible init)."""
    return torch.Generator().manual_seed(RANDOM_SEED)


def init_fn(w):
    """Xavier-uniform init with the ReLU gain sqrt(2)."""
    return nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())


model = init_model(get_model(**model_params), init_fn)
hooks = init_hooks()
trainer = run(model, train_dataset, **train_params, hooks=hooks)
Show code cell source
def moving_avg(a, window=0.05):
    """Trailing moving average of the sequence ``a``.

    The window spans ``w = max(1, int(window * len(a)))`` samples. The value
    at step ``s`` is the mean of ``a[max(0, s - w + 1) : s + 1]``, so the
    first ``w - 1`` points average over the samples seen so far. Previously
    a negative slice start made them empty (NaN).

    Returns:
        list of floats with the same length as ``a``.
    """
    w = max(1, int(window * len(a)))
    return [float(np.mean(a[max(0, s - w + 1): s + 1])) for s in range(len(a))]
def plot_training_loss(trainer, window=0.05, figsize=(5, 3)):
    """Plot the per-step training loss together with its moving average.

    The y-range is [min - 0.2, min(max + 0.2, 5.0)] of the logged losses.
    """
    losses = trainer.logs["train"]["loss"]
    smoothed = moving_avg(losses, window)
    lower, upper = min(losses) - 0.2, min(max(losses) + 0.2, 5.0)

    plt.figure(figsize=figsize)
    plt.plot(losses, alpha=0.5, label="train")
    plt.plot(smoothed, color="C0", label="train (MA)")
    plt.ylabel("loss")
    plt.xlabel("steps")
    plt.ylim(lower, upper)
    plt.grid(linestyle="dotted", alpha=0.6)
    plt.legend()


plot_training_loss(trainer)
Histograms#
The hooks stored activation and gradient statistics during training which we now visualize:
Show code cell source
def plot_training_histograms(stats: list[HookHandler],
                             bins: int = 100,
                             num_stds: int = 4,
                             cmaps: list[str] = None,
                             labels: list[str] = None,
                             figsize=(6, 2),
                             aspect: int = 5):
    """Plot histograms of recorded values over training.

    One row per handler in ``stats`` and one column per module (taken from
    ``stats[0].records``). Each panel is an image with training steps on the
    x-axis and value bins on the y-axis (positive values up); color encodes
    log(1 + count). The value range of a row is [-h, h] with
    h = ``num_stds`` * std over all of that row's records, so panels within a
    row are directly comparable (vanishing/exploding across layers shows up).
    Panels for modules missing from a handler's records are left empty.
    """
    rows = len(stats)
    modules = list(stats[0].records.keys())
    cols = len(modules)
    cmaps = ["viridis"] * rows if cmaps is None else cmaps
    H, W = figsize
    # squeeze=False: `ax` is always 2-D, even with a single row or column.
    fig, ax = plt.subplots(rows, cols, figsize=(H * rows, W * cols), squeeze=False)

    # Stack each module's records into a (steps, values) tensor and estimate
    # a common range per row (i.e. per handler).
    h = {}
    acts = {}
    for i in range(rows):
        for j, m in enumerate(modules):
            if m in stats[i].records:
                acts[i, j] = torch.stack([t.reshape(-1) for t in stats[i].records[m]], dim=0).cpu()
        row_data = [acts[i, j].reshape(-1) for j in range(cols) if (i, j) in acts]
        if row_data:
            h[i] = torch.concat(row_data).std().item() * num_stds

    # Draw one histogram image per (handler, module) panel.
    for i in range(rows):
        for j, m in enumerate(modules):
            if (i, j) not in acts:
                continue
            with torch.no_grad():
                # log(count + 1): compresses the dynamic range; empty bins map to 0.
                hists = [
                    (torch.histc(step_vals, bins=bins, min=-h[i], max=h[i]) + 1).log().numpy()
                    for step_vals in acts[i, j]
                ]
            image = np.stack(hists, axis=0).T  # (bins, steps)
            panel = ax[i, j]
            panel.set_title(m.__class__.__name__ + "." + str(j), size=10)
            # Flip vertically so that positive values point up.
            panel.imshow(np.flip(image, axis=0), aspect="auto", cmap=cmaps[i])
            panel.set_aspect(aspect)
            panel.set_yticks([])
            if j == 0 and labels is not None:
                panel.set_ylabel(f"{labels[i]}\n[-{h[i]:.1e}, {h[i]:.1e}]")
            else:
                panel.set_ylabel(f"[-{h[i]:.1e}, {h[i]:.1e}]")
            panel.set_xlabel("steps")

    fig.tight_layout()
    plt.show()
Show code cell source
def plot_training_stats(stats: list[HookHandler],
                        labels: list[str] = None,
                        colors: list[str] = None,
                        figsize=(6, 2)):
    """Plot the 16th/50th/84th percentiles of recorded values at each step.

    One row per handler, one column per module (taken from
    ``stats[0].records``). The band [p16, p84] holds the central 68% of the
    values (mean ± 1 std for a Gaussian). All panels in a row share
    symmetric y-limits.
    """
    rows = len(stats)
    modules = list(stats[0].records.keys())
    cols = len(modules)
    colors = [f"C{i}" for i in range(rows)] if colors is None else colors
    H, W = figsize
    # squeeze=False: `ax` is always 2-D, even with a single row or column.
    fig, ax = plt.subplots(rows, cols, figsize=(H * rows, W * cols), squeeze=False)

    for i in range(rows):
        h = 0  # half-height of the shared y-range for this row
        for j, m in enumerate(modules):
            if m not in stats[i].records:
                continue
            with torch.no_grad():
                r = torch.stack(list(stats[i].records[m]))
                r = r.view(r.shape[0], -1).cpu().numpy()
            u = np.percentile(r, q=50, axis=1)
            a = np.percentile(r, q=16, axis=1)  # 50 - 34
            b = np.percentile(r, q=84, axis=1)  # 50 + 34
            h = max(h, max(np.abs(a).mean(), np.abs(b).mean()).item())

            panel = ax[i, j]
            panel.set_title(m.__class__.__name__ + "." + str(j), size=10)
            panel.plot(a, color=colors[i])
            panel.plot(b, color=colors[i])
            panel.fill_between(np.arange(len(u)), a, b, color=colors[i], alpha=0.3)
            # Raw string: "\m" is an invalid escape sequence (SyntaxWarning on 3.12+).
            panel.plot(u, color="black", label=r"$\mu$", alpha=0.8)
            panel.set_xlabel("steps")
            panel.grid(linestyle="dashed", alpha=0.6)
            if j == 0 and labels is not None:
                panel.set_ylabel(labels[i])

        # set y-lims
        for j, m in enumerate(modules):
            if m in stats[i].records:
                ax[i, j].set_ylim(-2 * h, 2 * h)

    fig.tight_layout()
    plt.show()


print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_training_histograms([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(6, 1.2))
plot_training_stats([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], figsize=(6, 1.5))
loss (last 100 steps): 2.557870314121246
Figure. The first plot shows activation and gradient histograms over time. The range of values shown is \([-4 \sigma, 4 \sigma]\), where \(\sigma\) is computed over all time steps, units, and layers. Considering each layer separately, the network is training well. But between layers we see that the gradient distribution changes. Since the distributions are not symmetric, the second plot shows the 16th, 50th, and 84th percentiles of samples at each step.
# The same statistics for the ReLU layers: outputs and input gradients.
relu_stats = [relu_outs_stats, relu_grad_stats]
stat_labels = ["Outputs", "Gradients"]
plot_training_histograms(relu_stats, labels=stat_labels, cmaps=["viridis", "inferno"], figsize=(5.0, 1.5))
plot_training_stats(relu_stats, labels=stat_labels, figsize=(5, 1.5))
Show code cell output
Dying units#
Good initialization prevents dying units:
Show code cell source
def plot_dying_relus(trainer, stats):
    """Plot the moving average (5% window) of the fraction of dead units for
    each ReLU layer of ``trainer.model``, as recorded by ``stats``, a
    DyingReluStats handler."""
    relus = [m for m in trainer.model.modules() if isinstance(m, nn.ReLU)]
    plt.figure(figsize=(6, 3))
    for i, m in enumerate(relus):
        # Reuse the shared helper instead of duplicating the window logic.
        plt.plot(moving_avg(stats.records[m], window=0.05), label=f"ReLU.{i}")
    plt.legend()
    plt.grid(linestyle="dashed", alpha=0.6)
    plt.title(f"sat_threshold={stats.sat_threshold:.1e}, frac_threshold={stats.frac_threshold:.2f}", size=10)
    plt.xlabel("step")
    plt.ylabel("frac dead relu (MA)")
    plt.ylim(-0.1, 1.1)


print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_dying_relus(trainer, dead_relu_stats)
loss (last 100 steps): 2.557870314121246
Increasing the learning rate from 0.1 to 2.0. As discussed above, this can push weights to large values where the activations saturate, knocking the model into a flat region of the loss surface where the gradient from any sample of the data is vanishingly small. Hence, the model barely trains with further SGD steps.
from copy import copy
def _update(params, **kwargs) -> dict:
"""Helper to make stateless updates to params."""
params = copy(params)
for key, val in kwargs.items():
params[key] = val
return params
# Same model and init as the baseline, but with a 20x larger learning rate.
def init_fn(w):
    """Xavier-uniform init with the ReLU gain sqrt(2)."""
    return nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())


hooks = init_hooks()
model = init_model(get_model(**model_params), init_fn)
high_lr_params = _update(train_params, lr=2.0)
trainer = run(model, train_dataset, **high_lr_params, hooks=hooks)

final_loss = sum(trainer.logs["train"]["loss"][-100:]) / 100
print("loss (last 100 steps):", final_loss)
plot_dying_relus(trainer, dead_relu_stats)
plot_training_loss(trainer)
Show code cell output
# Linear layers under the high learning rate: output and gradient histograms.
plot_training_histograms(
    [linear_outs_stats, linear_grad_stats],
    labels=["Outputs", "Gradients"],
    cmaps=["viridis", "inferno"],
    figsize=(6, 1.2),
)
Show code cell output
# Linear layers under the high learning rate: percentile bands over training.
plot_training_stats(
    [linear_outs_stats, linear_grad_stats],
    labels=["Outputs", "Gradients"],
    figsize=(6, 1.2),
)
Layer normalization#
Recall that for weight initialization we have to specify the gain for each activation. It would be nice to learn the gain instead of setting it heuristically, or deriving it from equations. Moreover, from the above pathological examples, we find cases where initialization cannot help past the early steps (e.g. when we have large learning rates). To get stable activations during training, we want to normalize layer preactivations dynamically.
This can be done using normalization layers. In particular, we look at layer normalization [BKH16]. The following equations define the action of LayerNorm on an input \(\boldsymbol{\mathsf{z}}\) for one instance fed into the network:
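\[\mu = \frac{1}{n}\sum_{i=1}^{n} \mathsf{z}_i, \qquad \sigma^2 = \frac{1}{n}\sum_{i=1}^{n} (\mathsf{z}_i - \mu)^2, \qquad \text{LN}(\boldsymbol{\mathsf{z}})_i = \gamma_i \, \frac{\mathsf{z}_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i\]
where \(n = |\boldsymbol{\mathsf{z}}|\) and \(\epsilon\) is a small constant for numerical stability. This makes the outputs of each neuron lie in the same range. Consequently, this also helps control the magnitude of weights and weight gradients. Here we introduce trainable affine parameters \(\boldsymbol{\gamma}\) and \(\boldsymbol{\beta}\) that are applied per unit. This allows the output distributions to scale and shift; otherwise network expressiveness is degraded, with preactivations mostly in \([-1, 1]\). It can be shown that LN is invariant to rescaling and recentering of the whole weight matrix, as well as to rescaling of individual training cases [BKH16].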
# Manual LayerNorm over the last dimension, compared with nn.LayerNorm below.
n = 5
layernorm = nn.LayerNorm(n)
x = torch.randn(3, n)
u = x.mean(dim=-1, keepdim=True)               # per-row mean
v = x.var(dim=-1, correction=0, keepdim=True)  # biased (population) variance
eps = 1e-5
gamma = torch.ones(n)                          # affine params at init
beta = torch.zeros(n)
gamma * (x - u) / torch.sqrt(v + eps) + beta
tensor([[ 0.4980, 1.2870, -0.2261, 0.1768, -1.7358],
[ 1.5455, -1.1316, -0.2491, -0.8746, 0.7098],
[-0.0341, -1.5814, 1.2074, 0.8957, -0.4876]])
# Built-in nn.LayerNorm on the same input, for comparison with the manual formula.
layernorm(x)
tensor([[ 0.4980, 1.2870, -0.2261, 0.1768, -1.7358],
[ 1.5455, -1.1316, -0.2491, -0.8746, 0.7098],
[-0.0341, -1.5814, 1.2074, 0.8957, -0.4876]],
grad_fn=<NativeLayerNormBackward0>)
Remark. The affine parameters have the same shape as normalized_shape (i.e. the main argument of LayerNorm). It specifies the last D dimensions of the input tensor that are normalized over, where D is the rank of normalized_shape.
Immediate effects#
Updating our architecture with LN layers which are added before the activation function. Hence, preactivations are normalized around the active region of the activation, followed by a unit-wise shift and scale. Note that we also add normalization on logits:
def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int,
              layernorm: bool = False):
    """Build an MLP language model over a window of `block_size` token embeddings.

    Architecture: Embedding -> Flatten -> Linear -> [LayerNorm] -> ReLU, then
    `num_linear` blocks of Linear(width, width) -> [LayerNorm] -> ReLU, then a final
    Linear(width, vocab_size) -> [LayerNorm] producing the logits.

    Args:
        vocab_size: number of tokens (embedding rows and logit dimension).
        block_size: number of context tokens fed to the model.
        emb_size: embedding dimension per token.
        width: hidden layer width.
        num_linear: number of extra hidden Linear layers.
        layernorm: if True, insert nn.LayerNorm before each activation and on the logits.

    Returns:
        An nn.Sequential model.
    """
    def maybe_norm(dim: int) -> list:
        # Optional layers are expressed as data (an empty or one-element list)
        # instead of a conditional expression evaluated only for its side effect.
        return [nn.LayerNorm(dim)] if layernorm else []

    # Modules are constructed in the same order as before, so default-init RNG use is unchanged.
    layers = [
        nn.Embedding(vocab_size, emb_size),
        nn.Flatten(),
        nn.Linear(block_size * emb_size, width),
        *maybe_norm(width),
        nn.ReLU(),
    ]
    for _ in range(num_linear):
        layers += [nn.Linear(width, width), *maybe_norm(width), nn.ReLU()]
    layers += [nn.Linear(width, vocab_size), *maybe_norm(vocab_size)]
    return nn.Sequential(*layers)
Initialization. Having layer norm makes training less sensitive to weight initialization:
# For each Xavier gain, train one LayerNorm model and one norm-free (NF) model
# and keep their per-step training losses.
losses = {0.30: {}, 1.41: {}, 3.0: {}} # sqrt(2) = 1.414...
for gain in [0.30, 1.41, 3.0]:
    # init_fn closes over the loop variable `gain`. This is fine as long as init_model
    # applies it immediately, within this iteration (presumably; confirm in init_model).
    init_fn = lambda w: nn.init.xavier_uniform_(w, gain=gain, generator=g())
    model_ln = init_model(get_model(**model_params, layernorm=True), init_fn)
    model_nf = init_model(get_model(**model_params), init_fn)
    trainer_ln = run(model_ln, train_dataset, hooks=[], **train_params) # no hooks
    trainer_nf = run(model_nf, train_dataset, hooks=[], **train_params) # NF = norm free
    losses[gain]["ln"] = trainer_ln.logs["train"]["loss"]
    losses[gain]["nf"] = trainer_nf.logs["train"]["loss"]
Observe that networks with LN layers start with better loss values:
Show code cell source
# Smoothed training loss per gain: thin translucent lines are norm-free runs,
# thick lines are LN runs (same color for the same gain).
plt.figure(figsize=(6, 4))
window = 0.03  # smoothing window passed to moving_avg
for i, gain in enumerate(losses.keys()):
    train_loss = losses[gain]["nf"]
    plt.plot(moving_avg(train_loss, window), color=f"C{i}", label=f"gain={gain:.2f}", alpha=0.6)
for i, gain in enumerate(losses.keys()):
    train_loss = losses[gain]["ln"]
    plt.plot(moving_avg(train_loss, window), color=f"C{i}", label=f"gain={gain:.2f} (LN)", linewidth=2.0)
plt.ylabel("loss")
plt.xlabel("steps")
plt.grid(alpha=0.5, linestyle="dotted")
plt.ylim(2.2, 4.5)
plt.legend();
LN can be thought of as automatically and dynamically finding per unit gain:
# Using suboptimal PyTorch default gain:
init_fn = lambda w: nn.init.xavier_uniform_(w, gain=1/np.sqrt(3), generator=g())
model = init_model(get_model(layernorm=True, **model_params), init_fn)
# Fresh hooks for this run. The histograms below read linear_outs_stats / linear_grad_stats,
# which presumably are (re)created by init_hooks (confirm in init_hooks).
hooks = init_hooks()
trainer = run(model, train_dataset, **train_params, hooks=hooks)
plot_training_histograms([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(6, 1.2))
Show code cell output
Note that shrinking activations are corrected immediately (the correction is more gradual without LN):
# Mean training loss over the last 100 steps, then per-layer output/gradient statistics over training.
print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_training_stats([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], figsize=(6, 1.5))
loss (last 100 steps): 2.5370990657806396
Large learning rate. LN allows the network to tolerate larger weight updates:
# LN model with gain sqrt(2), trained with a large learning rate (lr=2.0).
init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
hooks = init_hooks()
model = init_model(get_model(layernorm=True, **model_params), init_fn)
# _update(train_params, lr=2.0): train_params with lr overridden (assumed to return a copy; confirm).
trainer = run(model, train_dataset, **_update(train_params, lr=2.0), hooks=hooks)
print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_dying_relus(trainer, dead_relu_stats)
loss (last 100 steps): 2.576110579967499
This looks better behaved than the previous plot (unnormalized with lr=2.0):
# Output and gradient histograms over training for the large-LR LN model.
plot_training_histograms([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(6, 1.2))
Show code cell output
# Per-layer summary statistics of the same outputs and gradients over training.
plot_training_stats([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], figsize=(6, 1.2))
Gradient normalization#
If we look at the above plot of the gradient distribution over time, LN seems to help not only with forward propagation (i.e. the model still makes meaningful predictions despite large weights), but also with backward propagation: the gradient distributions look better than those of the unnormalized network. This makes sense. The preactivation neurons \(\boldsymbol{\mathsf{z}}\) now interact through the shared mean and variance (Fig. 48), so these neurons, and consequently their weights, are forced to co-adapt. The exact dependency can be seen in the BP equations for LN, which we derive below.
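Let \(\boldsymbol{\mathsf{y}} = \text{LN}_{\gamma, \beta}(\boldsymbol{\mathsf{z}}).\) The gradients of the affine parameters are straightforward (summing over the batch index \(b\)):
\[
\frac{\partial \mathcal{L}}{\partial \gamma_j} = \sum_b \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_{bj}}\, \hat{{\mathsf{z}}}_{bj}, \qquad
\frac{\partial \mathcal{L}}{\partial \beta_j} = \sum_b \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_{bj}}.
\]
Next, we calculate the gradients of the layer statistics:
\[
\frac{\partial \mathcal{L}}{\partial \sigma^2} = -\frac{1}{2}(\sigma^2 + \epsilon)^{-3/2} \sum_j \gamma_j \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j} ({\mathsf{z}}_j - \mu), \qquad
\frac{\partial \mathcal{L}}{\partial \mu} = -\frac{1}{\sqrt{\sigma^2 + \epsilon}} \sum_j \gamma_j \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j}.
\]
The contribution of \(\mu\) through \(\sigma^2\) vanishes since \(\sum_j ({\mathsf{z}}_j - \mu) = 0\). Pushing these gradients to the input node:
\[
\frac{\partial \mathcal{L}}{\partial {\mathsf{z}}_j} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left( \gamma_j \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j} - \frac{1}{n}\, \hat{{\mathsf{z}}}_j \sum_k \gamma_k \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_k} \hat{{\mathsf{z}}}_k - \frac{1}{n} \sum_k \gamma_k \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_k} \right).
\]
The last term centers the output gradients. The middle term removes the component of the output gradient along the standardized preactivation \(\hat{\boldsymbol{\mathsf{z}}}\) in the layer dimension. Observe that \(\sum_{j} \frac{\partial \mathcal{L}}{\partial {\mathsf{z}}_j} = 0\), so the mean of the downstream gradients is zero. Moreover, it can be shown (see [XSZ+19]) that:
\[
\text{Var}_j\!\left(\frac{\partial \mathcal{L}}{\partial {\mathsf{z}}_j}\right) \leq \frac{1}{\sigma^2 + \epsilon}\, \text{Var}_j\!\left(\gamma_j \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j}\right).
\]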
A larger variance of the preactivations results in a smaller variance of the gradients, and vice versa. This centering and rescaling of gradients can be thought of as gradient normalization. It carries over to the weight gradients, since \(\frac{\partial\mathcal{L}}{\partial \boldsymbol{\mathsf{w}}} = \boldsymbol{\mathsf{x}}^\top \frac{\partial\mathcal{L}}{\partial \boldsymbol{\mathsf{z}}}\), where the factor \(\frac{\partial\mathcal{L}}{\partial \boldsymbol{\mathsf{z}}}\) is centered with controlled variance.
Gradient checking. We verify the formulas empirically using `autograd`. Here we assume `x` is the output of a previous activation layer. Then `z` is the current preactivation, which we pass through LN to get `y`. The outputs `y` are passed as logits to compute the cross-entropy with respect to some random labels, which we use to backpropagate.
# LayerNorm gradient check: build x -> z = x @ w -> LN -> y -> cross-entropy
# and keep the autograd gradients of every intermediate for comparison.
n = 128  # features normalized by LN
m = 256  # input features
# Forward pass graph
B = 10000  # batch size
x = torch.tanh(torch.randn((B, m), requires_grad=True))  # stands in for a previous activation output
w = 3 * torch.randn((m, n), requires_grad=True)
z = x @ w  # (B, n) preactivations
# Nudge params to make sure 1 and 0 do not hide bugs
ε = 1e-5
γ = torch.randn(n).requires_grad_(True)
β = torch.randn(n).requires_grad_(True)
v = z.var(1, keepdim=True, unbiased=False) # (!) Hours were spent finding an error in the calculation. Forgot default: unbiased=True
μ = z.mean(1, keepdim=True)
ẑ = (z - μ) / torch.sqrt(v + ε)
y = γ * ẑ + β
# Non-leaf tensors (x, w, z, ...) only keep .grad after retain_grad().
for u in [z, y, x, w, γ, β, v, μ]:
    u.retain_grad()
t = torch.randint(0, n, size=(B,))  # random class labels
# Mean cross-entropy, written out as -log softmax probability of the true class.
loss = -torch.log(F.softmax(y, dim=1)[range(B), t]).sum() / B
loss.backward()
Computing gradients by hand:
# Manual backprop through LN, mirroring the forward pass above.
dy = y.grad  # upstream gradient dL/dy, taken from autograd
dγ = (dy * ẑ).sum(0)  # affine-parameter gradients sum over the batch
dβ = dy.sum(0)
# dL/dv through ẑ = (z - μ) * (v + ε)^(-1/2)
dv = -(γ * dy * (z - μ)).sum(1, keepdim=True) * 0.5 * (v + ε) ** -1.5
# dL/dμ through ẑ only; the path through v vanishes because sum_j (z_j - μ) = 0
dμ = -(γ * dy).sum(1, keepdim=True) / (v + ε) ** 0.5
# dL/dz = (γ dy - ẑ * mean_j(γ dy ẑ) - mean_j(γ dy)) / sqrt(v + ε), per batch row
dz = γ * dy - (1 / n) * ẑ * (γ * dy * ẑ).sum(1, keepdim=True) - (1 / n) * (γ * dy).sum(1, keepdim=True)
dz *= (1 / torch.sqrt(v + ε))
dx = dz @ w.T  # since z = x @ w
dw = x.T @ dz
Gradient check:
Show code cell source
def compare(name, dt, t):
    """Print how closely a hand-derived gradient `dt` matches autograd's `t.grad`.

    Reports bitwise equality, closeness under torch.allclose(rtol=1e-7), and the
    largest absolute elementwise difference.
    """
    reference = t.grad
    is_exact = torch.all(dt == reference).item()
    largest_gap = (dt - reference).abs().max().item()
    is_close = torch.allclose(dt, reference, rtol=1e-7)
    print(f'{name} | exact: {str(is_exact):5s} | approx: {str(is_close):5s} | maxdiff: {largest_gap:.2e}')
# Check every hand-derived gradient against the autograd result.
compare('y', dy, y)
compare('γ', dγ, γ)
compare('β', dβ, β)
compare('v', dv, v)
compare('μ', dμ, μ)
compare('z', dz, z)
compare('w', dw, w)
compare('x', dx, x)
y | exact: True | approx: True | maxdiff: 0.00e+00
γ | exact: True | approx: True | maxdiff: 0.00e+00
β | exact: True | approx: True | maxdiff: 0.00e+00
v | exact: False | approx: True | maxdiff: 8.53e-14
μ | exact: False | approx: True | maxdiff: 3.64e-12
z | exact: False | approx: True | maxdiff: 1.82e-12
w | exact: False | approx: True | maxdiff: 2.04e-10
x | exact: False | approx: True | maxdiff: 2.18e-11
Demonstrating the variance relation. Note that the factor \(\gamma_j\) is obtained by dimensional analysis, since the proof in [XSZ+19] sets \(\gamma_j = 1\) for convenience. Moreover, \(\gamma_j \frac{\partial \mathcal L}{\partial {\mathsf{y}}_j}\) always appears as a single unit in the backpropagation equations. The following check confirms that the bound holds with this factor:
# Fraction of batch rows satisfying Var(dz) <= Var(γ * dy) / (v + ε).
(dz.var(1) <= 1 / (v.view(-1) + ε) * (γ * dy).var(1)).float().mean().item()
0.9995999932289124
Observe that the bound is super tight:
# Per-row slack of the bound Var(dz) <= Var(γ * dy) / (v + ε).
gap = 1 / (v.view(-1) + ε) * (γ * dy).var(1) - dz.var(1)
# Smallest and largest slack. min()/max() already reduce to scalars, so no extra .mean() is needed.
gap.min().item(), gap.max().item()
(-2.710505431213761e-20, 5.109291895816215e-14)
Remark. The above code uses a lot of broadcasting. The rules are as follows. First, when the ranks of two tensors differ, the shape of the lower-rank tensor is padded with 1s on the left until the ranks are equal. Then, in each dimension where the sizes differ, a size-1 dimension is stretched to match the other shape. If two sizes in the same dimension differ and neither is 1, an error is raised.
For example, `z - z.mean(1)` fails. The shapes are `[B, n]` and `[B,]`, respectively. The latter is padded to `[1, B]` by the first rule, which is incompatible with `[B, n]` since \(B \neq n\). If \(B = n\), it would instead silently subtract the wrong means. A more insidious mistake, since no error is raised, is `(1 / v) * (γ * dy).var(1)`. Here the shapes are `[B, 1]` and `[B,]`, respectively, so the latter becomes `[1, B]`. Hence, the two tensors broadcast to a final shape of `[B, B]`!
On the other hand, `γ * dy` works as intended, since the shapes are `[n,]` and `[B, n]`. The former becomes `[1, n]`, which is then stretched to `[B, n]`.
Weight update ratio#
Another thing we can look at is the magnitude of weight updates relative to the weights. That is, the ratio \(\zeta_k = \frac{\|\lambda {\nabla}{\boldsymbol{\mathsf{w}}_k}\|}{\|\boldsymbol{\mathsf{w}}_k\|}\) where \({\nabla}{\boldsymbol{\mathsf{w}}_k} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{w}}_k}.\) A low overall \(\zeta_k\) means that convergence is slow, whereas a high \(\zeta_k\) allows the network to quickly unlearn good weight configurations.
Having layers with varying \(\zeta_k\) is not good for training as weights learned by slower layers can become outdated as faster layers learn new weights, unless the layers are doing different things (e.g. embeddings and logits). This can be seen in the plot below.
class WeightUpdateRatio(HookHandler):
    """Backward hook that records lr * ||grad|| / ||weight|| per hooked module.

    Args:
        lr: learning rate λ used to scale the gradient norm.
    """

    def __init__(self, lr: float):
        super().__init__()
        self.lr = lr

    def hook_fn(self, module, grad_in, grad_out):
        # Ensure every hooked module has an entry, even if no gradient matches below.
        history = self.records.setdefault(module, [])
        for grad in grad_in:
            if grad is None:
                continue
            weight = module.weight.data
            # Only gradient tensors shaped like the weight are used.
            if grad.shape != weight.shape:
                continue
            ratio = torch.linalg.norm(grad) / torch.linalg.norm(weight)
            history.append(self.lr * ratio.item())
Making the network deeper:
# Deeper network (4 extra hidden layers): record weight-update ratios for the
# unnormalized ("un") and LayerNorm ("ln") variants.
num_linear = 4
lr = 0.6
steps = 2000
init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
weight_update_ratios = {}
for ln in ["un", "ln"]:
    use_ln = ln == "ln"
    model = init_model(get_model(layernorm=use_ln, **_update(model_params, num_linear=num_linear)), init_fn)
    weight_update_ratio = WeightUpdateRatio(lr)
    # Only the backward ratio hook is attached, on Linear and Embedding layers.
    hooks = [{"fwd_hooks": [], "bwd_hooks": [weight_update_ratio], "layer_types": (nn.Linear, nn.Embedding)}]
    trainer = run(model, train_dataset, **_update(train_params, lr=lr, steps=steps), hooks=hooks)
    print(f"[{ln}] loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
    weight_update_ratios[ln] = weight_update_ratio
[un] loss (last 100 steps): 2.540506160259247
[ln] loss (last 100 steps): 2.51398802280426
Show code cell source
# Smoothed log10 update ratios per hooked layer: unnormalized (left) vs LN (right).
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ratio_label = r"$\log(\lambda \| \nabla \mathsf{\mathbf{w}} \|\; / \;\| \mathsf{\mathbf{w}} \|)$"
panels = [("un", "Unnormalized"), ("ln", "Normalized (linear + LN)")]
for panel, (run_key, panel_title) in zip(ax, panels):
    records = weight_update_ratios[run_key].records
    n_layers = len(records)
    for j, m in enumerate(records.keys()):
        panel.plot(
            moving_avg(np.log10(np.array(records[m])), window=0.05),
            label=m.__class__.__name__ + "." + str(n_layers - 1 - j),
        )
    panel.grid(linestyle="dotted", alpha=0.6)
    panel.set_title(panel_title, fontsize=12)
    panel.set_ylabel(ratio_label)
    panel.set_xlabel("steps")
    panel.set_ylim(-4, -1.0)
ax[1].legend()
Figure. Model training with +4 hidden layers, Xavier-uniform init with gain \(\sqrt{2}\) (Kaiming-style), and a learning rate of 0.6. Since we took the base-10 logarithm, y = -2 corresponds to an update ratio of 0.01 for one mini-batch. This roughly means that it would take 100 steps to completely erase the current weights. The update ratios for the logits and embedding layers are clearly separated: the logits weights train faster while the embeddings train slower. Moreover, notice that the update rates for the hidden layers are well ordered.
Appendix: Activations#
# Xavier gain per activation name (note: rebinds the earlier loop variable `gain`).
gain = {
    "tanh": 5/3,
    "relu": math.sqrt(2),
    "selu": 1, # https://github.com/pytorch/pytorch/pull/53694#issuecomment-795683782
    "gelu": math.sqrt(2), # GELU plot looks like ReLU
}
# Feel free to add to this list
# Maps activation name -> nn.Module class (instantiated with no arguments by callers).
activation_fns = {
    "tanh": nn.Tanh,
    "selu": nn.SELU,
    "relu": nn.ReLU,
    "gelu": nn.GELU
}
def d(act, x):
    """Compute derivative of activation with respect to x.

    Args:
        act: activation module class (e.g. nn.ReLU); instantiated with no arguments.
        x: input tensor with requires_grad=True.

    Returns:
        Tensor shaped like `x`. For elementwise activations this is f'(x) at each point.

    Uses torch.autograd.grad instead of backward(), so nothing accumulates in
    `x.grad`: repeated calls return the same values and `x.grad` is left untouched.
    """
    f = act()
    y = f(x)
    # A vector of ones sums the (diagonal) Jacobian columns, giving f'(x) elementwise.
    (dx,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(x))
    return dx
# Region x s.t. |f'(x)| <= 0.1:
# Each entry maps a tensor x to a boolean mask of points in that region.
saturation_fns = {
    "tanh": lambda x: x.abs() > 1.8183,  # 1 - tanh(x)^2 <= 0.1  <=>  |x| >= atanh(sqrt(0.9))
    "selu": lambda x: x < -2.8668,  # SELU'(x) = scale * alpha * exp(x) for x < 0
    "relu": lambda x: x < 0.0,
    "gelu": lambda x: (x < -1.8611) | ((-1.0775 < x) & (x < -0.5534)),  # two disjoint intervals
}
Show code cell source
def plot_activation(act: str, ax, x):
    """Draw activation `act` and its derivative over `x` on `ax`, shading saturation regions.

    Args:
        act: name of the activation (key of `activation_fns`).
        ax: matplotlib Axes to draw into.
        x: 1-D input tensor with requires_grad=True; its `.grad` is cleared first.
    """
    # Shaded x-intervals per activation name: (start, end, extra axvspan kwargs).
    shaded = {
        "tanh": [(-5, -1.8183, {}), (1.8183, 5, {})],
        "relu": [(-5, 0, {})],
        "leaky_relu": [(-5, 0, {})],
        "elu": [(-5, -2.3025, {})],
        "selu": [(-5, -2.8668, {})],
        "gelu": [(-5, -1.8611, {}), (-1.0775, -0.5534, {"label": "|f'(x)| ≤ 0.1"})],
    }
    x.grad = None
    act_cls = activation_fns[act]
    out = act_cls()(x)
    deriv = d(act_cls, x)
    xs = x.detach().numpy()
    ys = out.detach().numpy()

    ax.plot(xs, ys, linewidth=2, color="red", label="f(x)")
    ax.plot(xs, deriv, linewidth=2, color="black", label="f'(x)")
    ax.set_title(act)
    ax.grid()
    ax.set_ylim(-2, 4)
    for start, end, extra in shaded.get(act, []):
        ax.axvspan(start, end, -10, 10, color="lightgray", **extra)
    ax.legend(loc="lower right")
# Plotting
# Two columns, one subplot per activation. squeeze=False keeps `ax` 2-D even when
# there is only one row, so ax[row, col] indexing works for any number of activations.
rows = math.ceil(len(activation_fns) / 2.0)
fig, ax = plt.subplots(rows, 2, figsize=(10, rows*3.5), squeeze=False)
x = torch.linspace(-5, 5, 1000, requires_grad=True)  # shared inputs; plot_activation resets x.grad
for i, act_name in enumerate(activation_fns.keys()):
    plot_activation(act_name, ax[divmod(i, 2)], x) # divmod(m, n) = m // n, m % n
fig.tight_layout()
Figure. Activation functions and their derivatives. Saturation regions (\(|f'(x)| \leq 0.1\)) are highlighted in grey. Note that GELU has four regions, whereas Tanh has three and ReLU and SELU have two. SELU has a discontinuity in its derivative at zero, which can cause issues with a large LR.
Having outputs that range to \(+\infty\) for ReLU (also the case for SELU and GELU) increases network expressivity. On the other hand, Tanh is bounded in \([-1, 1]\), with gradients that saturate in the asymptotes. Moreover, a derivative of 1 over a large region helps to accelerate training and avoid vanishing gradients. Recall that if \(\boldsymbol{\mathsf{y}} = \varphi(\boldsymbol{\mathsf{z}})\), where \(\boldsymbol{\mathsf{z}} = \boldsymbol{\mathsf{x}}\boldsymbol{\mathsf{w}}\) and \(\varphi\) is an activation, then
\[
\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{w}}} = \boldsymbol{\mathsf{x}}^\top \left( \varphi^{\prime}(\boldsymbol{\mathsf{z}}) \odot \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{y}}} \right).
\]
Here \(\varphi^{\prime}(\boldsymbol{\mathsf{z}})\) is a matrix with entries \(\varphi^{\prime}(\boldsymbol{\mathsf{z}}_{bj})\) and \(\odot\) denotes the Hadamard product.
Stacking layers effectively forms a chain of weight-matrix multiplications, with the entries of the gradients scaled in between by derivatives of the activation function. For \(\varphi = \tanh\) the gradients are scaled down, since \(\tanh^\prime(z) \leq 1\) for any \(z \in \mathbb R.\) Hence, a deep stack of fully-connected hidden layers can result in vanishing gradients at initialization. ReLU is interesting because it allows sparse gradient updates, since \(\varphi^\prime(z) \in \{0, 1\}.\)
def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int,
              activation: str = "relu",
              layernorm: bool = False):
    """Build the MLP language model with a configurable activation.

    Architecture: Embedding -> Flatten -> Linear -> [LayerNorm] -> act, then
    `num_linear` blocks of Linear(width, width) -> [LayerNorm] -> act, then a final
    Linear(width, vocab_size) -> [LayerNorm] producing the logits.

    Args:
        vocab_size: number of tokens (embedding rows and logit dimension).
        block_size: number of context tokens fed to the model.
        emb_size: embedding dimension per token.
        width: hidden layer width.
        num_linear: number of extra hidden Linear layers.
        activation: key into `activation_fns` (a KeyError is raised for unknown names).
        layernorm: if True, insert nn.LayerNorm before each activation and on the logits.

    Returns:
        An nn.Sequential model.
    """
    act_fn = activation_fns[activation]

    def maybe_norm(dim: int) -> list:
        # Optional layers as data rather than a side-effecting conditional expression.
        return [nn.LayerNorm(dim)] if layernorm else []

    # Module construction order is unchanged, so default-init RNG use is identical.
    layers = [
        nn.Embedding(vocab_size, emb_size),
        nn.Flatten(),
        nn.Linear(block_size * emb_size, width),
        *maybe_norm(width),
        act_fn(),
    ]
    for _ in range(num_linear):
        layers += [nn.Linear(width, width), *maybe_norm(width), act_fn()]
    layers += [nn.Linear(width, vocab_size), *maybe_norm(vocab_size)]
    return nn.Sequential(*layers)
Running an experiment:
# Train one model per (activation, LayerNorm on/off) pair with statistics hooks attached.
# Trainers and hook objects are kept for the analysis cells below.
model_params = {
    "vocab_size": 29,
    "block_size": 3,
    "emb_size": 10,
    "width": 30,
    "num_linear": 6
}
train_params = {
    "steps": 2000,
    "batch_size": 32,
    "lr": 0.2
}
train_hooks = {"ln": {}, "un": {}}
trainers = {"ln": {}, "un": {}}
for act in activation_fns.keys():
    for use_ln in [True, False]:
        # init_fn closes over `act`; this is fine because it is passed to init_model
        # within this same iteration.
        init_fn = lambda w: nn.init.xavier_uniform_(w, gain=gain[act], generator=g())
        activation_stats = OutputStats()
        weight_grad_stats = WeightGradientStats()
        weight_update_ratio = WeightUpdateRatio(train_params["lr"])
        hooks = [
            {
                "fwd_hooks": [],
                "bwd_hooks": [weight_grad_stats, weight_update_ratio],
                "layer_types": (nn.Linear, nn.Embedding)
            },
            {
                "fwd_hooks": [activation_stats],
                "bwd_hooks": [],
                "layer_types": tuple(activation_fns.values())
            }
        ]
        model = init_model(get_model(activation=act, layernorm=use_ln, **model_params), init_fn)
        trainer = run(model, train_dataset, **train_params, hooks=hooks)
        k = "ln" if use_ln else "un"
        trainers[k][act] = trainer
        # Build the hook record directly. The previous `{} if ... else ""` form would
        # have crashed with a TypeError on any repeated (k, act) key.
        train_hooks[k][act] = {
            "outs": activation_stats,
            "grad": weight_grad_stats,
            "rate": weight_update_ratio,
        }
SELU does not train well with a large LR compared to other activations (it can diverge even at lr=0.3), probably due to the discontinuity in its derivative at zero. The initial Tanh loss improves with LN, consistent with the above discussion.
Show code cell source
# Mean training loss over the first and last 100 steps, per activation and normalization.
# Iterates directly; the enumerate indices of the original were unused.
for ln in ["un", "ln"]:
    message = "train loss (normalized)" if ln == "ln" else "train loss (unnormalized)"
    print(message)
    for act in activation_fns:
        trainer = trainers[ln][act]
        train_loss = trainer.logs["train"]["loss"]
        start = np.array(train_loss[:100]).mean()
        end = np.array(train_loss[-100:]).mean()
        diff = end - start
        print(f" {act} start: {start:.3f} end: {end:.3f} diff: {diff:.3f}")
    print()
train loss (unnormalized)
tanh start: 2.996 end: 2.503 diff: -0.493
selu start: 2.975 end: 2.503 diff: -0.471
relu start: 3.005 end: 2.521 diff: -0.485
gelu start: 3.015 end: 2.487 diff: -0.528
train loss (normalized)
tanh start: 2.968 end: 2.479 diff: -0.490
selu start: 2.919 end: 2.484 diff: -0.435
relu start: 3.002 end: 2.559 diff: -0.442
gelu start: 2.985 end: 2.505 diff: -0.480
Histograms. Activation and gradient distributions at the last step of the hidden linear layers:
Show code cell source
# Final-step distributions of hidden-layer activations and weight gradients,
# one column per activation: unnormalized (top row) vs LayerNorm (bottom row).
step = -1
for key in ["outs", "grad"]:
    fig, ax = plt.subplots(2, len(activation_fns), figsize=(12, 6))
    # Hooked modules to read: activation modules for outputs, Linear layers for gradients.
    layer_types = tuple(activation_fns.values()) if key == "outs" else nn.Linear
    for i, act in enumerate(activation_fns.keys()):
        h = 0  # running max |99th percentile|, used as a symmetric x-limit
        for k, ln in enumerate(["un", "ln"]):
            records = train_hooks[ln][act][key].records
            with torch.no_grad():
                for j, m in enumerate(records.keys()):
                    # Skip the first and last hooked modules.
                    if isinstance(m, layer_types) and 0 < j < len(records) - 1:
                        data = records[m][step].reshape(-1).cpu().numpy()
                        h = max(h, np.abs(np.percentile(data, 99)))
                        # sns.distplot is deprecated; kdeplot draws the equivalent KDE
                        # curve that distplot(hist=False) produced.
                        sns.kdeplot(data, ax=ax[k, i], label=f"Linear.{j+1}")
            ax[k, i].set_xlim(-h, h)
            ax[k, i].set_title(f"{act} ({'unnormalized' if ln == 'un' else 'layernorm'})")
            ax[k, i].ticklabel_format(axis="x", style="sci", scilimits=(-2, 2))
            ax[k, i].legend(fontsize=8)
    fig.suptitle(f"{'Activation' if key == 'outs' else 'Gradient'} distribution at final step", fontsize=15)
    fig.tight_layout()
Observe that Tanh outputs have a small range compared to the other activations. ReLU and GELU have tails on the positive side, while SELU is more symmetric, with a heavy tail on the negative side. LN makes the distributions more similar across layers. The plot of percentiles per layer below allows us to quantify this better:
Show code cell source
# Per-layer 16th/50th/84th percentiles at the final step (hidden layers only):
# dotted = unnormalized, solid = LayerNorm; median drawn in black.
step = -1
for key in ["outs", "grad"]:
    fig, ax = plt.subplots(1, len(activation_fns), figsize=(10, 3))
    # Hooked modules to read: activation modules for outputs, Linear layers for gradients.
    layer_types = tuple(activation_fns.values()) if key == "outs" else nn.Linear
    for ln in ["un", "ln"]:
        is_ln = ln == "ln"
        line_kw = {
            "alpha": 1.0 if is_ln else 0.8,
            "linewidth": 2.0 if is_ln else 1.5,
            "linestyle": "solid" if is_ln else "dotted",
        }
        label = "LN" if is_ln else "un"
        for i, act in enumerate(activation_fns.keys()):
            records = train_hooks[ln][act][key].records
            lower, median, upper = [], [], []
            with torch.no_grad():
                for j, m in enumerate(records.keys()):
                    # Skip the first and last hooked modules.
                    if not (isinstance(m, layer_types) and 0 < j < len(records) - 1):
                        continue
                    data = records[m][step].reshape(-1).cpu().numpy()
                    p16, p50, p84 = np.percentile(data, [16, 50, 84])
                    lower.append(p16)
                    median.append(p50)
                    upper.append(p84)
            layer_idx = np.arange(len(lower)) + 2
            color = f"C{i}"
            ax[i].plot(layer_idx, lower, color=color, label=label, **line_kw)
            ax[i].plot(layer_idx, upper, color=color, **line_kw)
            ax[i].plot(layer_idx, median, color="black", **line_kw)
            ax[i].fill_between(layer_idx, lower, upper, color=color, alpha=0.2)
            ax[i].set_title(act)
            ax[i].set_xlabel("layer")
            ax[i].set_ylabel(r"$\sigma$")
            ax[i].set_ylim(*((-3, 3) if key == "outs" else (-0.05, 0.05)))
            ax[i].grid(alpha=0.4, linestyle="dashed", linewidth=0.6)
            ax[i].legend(fontsize=8)
            ax[i].set_xticks(layer_idx)
    fig.suptitle(f"{'Activation' if key == 'outs' else 'Weight gradient'} distribution at final step", fontsize=15)
    fig.tight_layout()
ReLU and GELU have weight gradients that decrease across layers, while Tanh and SELU have weight gradients that increase. In general, the propagation of weight gradients seems more uniform across layers with LN. Next, observe that SELU (and, to a lesser degree, GELU), which are not symmetric, have shifting median activations.
This biases the inputs \(\boldsymbol{\mathsf{x}}\) of the next layer toward all having the same sign. Recall that \(\frac{\partial\mathcal{L}}{\partial {\mathsf{w}}_{ij}} = {\mathsf{x}}_i \frac{\partial\mathcal{L}}{\partial {\mathsf{z}}_j},\) so \(\frac{\partial\mathcal{L}}{\partial {\mathsf{w}}_{ij}}\) will have the same sign for all \(i,\) which can slow down learning. Also, given two independent random vectors as input, the inner product of their outputs tends to become large and positive:
# Bias-shift demo: mean |inner product| of two independently projected random inputs,
# without and with a +0.2 shift of every input coordinate.
n = 1000
w = torch.randn(20, 20)
a = torch.randn(n, 20)
b = torch.randn(n, 20)
print(((a @ w) * (b @ w)).sum(dim=1).abs().mean())
print((((a + 0.2) @ w) * ((b + 0.2) @ w)).sum(dim=1).abs().mean())
tensor(107.3766)
tensor(110.3988)
This effect compounds with depth. Hence, deep networks affected by bias shift tend to predict the same label for all training examples at initialization. LN counteracts bias shift by ensuring that the preactivations of each instance have zero mean across the layer.
Weight update ratio. LN improves the separation and tightness of the update-ratio curves. SELU looks ideal (e.g. fewer fluctuations, and tight curves across hidden layers of the same shape), and these properties seem to be correlated with better loss. The significant improvement in the Tanh and SELU curves may explain why their performance improved with LN.
Show code cell source
from typing import Literal
def plot_weight_update_ratio(key: Literal["un", "ln"], ylim=(-3.0, -0.5), legend=False):
    """Plot smoothed log10 weight-update ratios per hooked layer, one panel per activation.

    Args:
        key: "un" for the unnormalized runs, "ln" for the LayerNorm runs.
        ylim: y-axis limits shared by all panels.
        legend: if True, draw a legend on the last panel.
    """
    # Same notation for w in numerator and denominator (and as in the earlier update-ratio plot).
    ylabel = r"$\log(\lambda \| \nabla \mathsf{\mathbf{w}} \|\; / \;\| \mathsf{\mathbf{w}} \|)$"
    fig, ax = plt.subplots(1, len(activation_fns), figsize=(12, 5))
    for i, act in enumerate(activation_fns):
        records = train_hooks[key][act]["rate"].records
        title = f"{act} (normalized)" if key == "ln" else f"{act} (unnormalized)"
        for j, m in enumerate(records.keys()):
            ax[i].plot(moving_avg(np.log10(np.array(records[m])), window=0.05), label=m.__class__.__name__ + "." + str(len(records) - 1 - j))
        ax[i].grid(linestyle="dotted", alpha=0.6)
        ax[i].set_title(title, fontsize=12)
        ax[i].set_ylabel(ylabel)
        ax[i].set_xlabel("steps")
        ax[i].set_ylim(*ylim)
    if legend:
        ax[-1].legend(fontsize=8, loc="lower right")
    fig.tight_layout()
plot_weight_update_ratio("un", ylim=(-4.0, -1))
plot_weight_update_ratio("ln", ylim=(-4.0, -1), legend=True)
Appendix: Rank collapse#
Rank collapse happens when the outputs of a layer lie in a small subspace of its output space. Xavier initialization by itself cannot prevent rank collapse. In fact, the gap between the largest singular value and the remaining ones, for a product of Gaussian matrices sampled with Xavier initialization, grows exponentially with depth (§3.2 of [DKB+20]). Activations help mitigate rank collapse by adding non-linearities between the weight matrices.
One way to measure rank numerically is by using singular values. Every matrix \(\boldsymbol{\mathsf{A}}\) has a singular value decomposition (SVD) \(\boldsymbol{\mathsf{A}} = \boldsymbol{\mathsf{U}} \boldsymbol{\Sigma} \boldsymbol{\mathsf{V}}^\top\), where \(\boldsymbol{\mathsf{U}}\) and \(\boldsymbol{\mathsf{V}}\) are orthogonal matrices whose columns form orthonormal bases of the output and input spaces, respectively. \(\boldsymbol{\Sigma}\) is a diagonal matrix containing the singular values of \(\boldsymbol{\mathsf{A}}.\) The number of nonzero singular values of a matrix equals its rank. In particular, the first singular value \(\sigma_1\) equals the operator norm:
\[
\sigma_1 = \|\boldsymbol{\mathsf{A}}\| = \max_{\|\boldsymbol{\mathsf{x}}\| = 1} \|\boldsymbol{\mathsf{A}}\boldsymbol{\mathsf{x}}\|.
\]
Geometrically, \(\sigma_1\) is the length of the longest semi-axis of the ellipsoid obtained by transforming the unit sphere of the input space by \(\boldsymbol{\mathsf{A}}.\) The other singular values are the lengths of the other semi-axes. For example, if only two singular values are nonzero, then the output ellipsoid, and hence the entire image of \(\boldsymbol{\mathsf{A}}\), lies in a two-dimensional subspace.
Show code cell source
# A random 2x2 matrix A maps the unit circle to an ellipse whose semi-axes are
# s[i] * u[:, i] (singular values times left singular vectors).
rng = np.random.RandomState(0)
A = rng.randn(2, 2)
u, s, vT = np.linalg.svd(A)
N = 100
# N distinct angles covering [0, 2π) exactly once.
t = np.linspace(0, 1, N, endpoint=False)
unit_circle = np.stack([np.cos(2*np.pi*t), np.sin(2*np.pi*t)], axis=0)
outputs = A @ unit_circle
# Plot
plt.scatter(unit_circle[0, :], unit_circle[1, :], s=0.8, label="x")
plt.scatter(outputs[0, :], outputs[1, :], s=0.8, label="Ax")
plt.axis('scaled')
# Checking: over sampled unit vectors, max ‖Ax‖ approximates s[0] and min ‖Ax‖ approximates s[1].
output_norms = np.linalg.norm(outputs, axis=0)
max_output_norm = output_norms.max()
min_output_norm = output_norms.min()
print("A=")
print(A)
print()
print(f'max ‖Ax‖ = {max_output_norm:.2f} (‖x‖ = 1)')
print(f'min ‖Ax‖ = {min_output_norm:.2f}')
print(f'σ₁ = {s[0]:.2f}')
print(f'σ₂ = {s[1]:.2f}')
# Plotting singular vectors as axis of ellipse
plt.arrow(0, 0, s[0]*u[0, 0], s[0]*u[1, 0], width=0.01, length_includes_head=True)
plt.arrow(0, 0, s[1]*u[0, 1], s[1]*u[1, 1], width=0.01, length_includes_head=True)
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.legend()
plt.savefig('../../img/nn/05-singular-ellipsoid.svg')
plt.close()
A=
[[1.76405235 0.40015721]
[0.97873798 2.2408932 ]]
max ‖Ax‖ = 2.75 (‖x‖ = 1)
min ‖Ax‖ = 1.29
σ₁ = 2.75
σ₂ = 1.29
Batch normalization. Below we test rank collapse with LN and batch normalization (BN) [IS15]. BN is similar to LN except that the normalization occurs in the batch dimension. As such, BN has different prediction algorithms for training and inference since we have to estimate batch statistics when making predictions with single instances at test time. This complicates using BN in practice (also see [WJ21]). However, in addition to having the same benefits as LN, BN avoids rank collapse [DKB+20] which will be shown in the following experiment.
def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int,
              activation: str = "relu",
              norm: str = None):
    """Build the MLP language model with a configurable activation and normalization.

    Architecture: Embedding -> Flatten -> Linear -> [norm] -> act, then `num_linear`
    blocks of Linear(width, width) -> [norm] -> act, then a final
    Linear(width, vocab_size) -> [norm] producing the logits.

    Args:
        vocab_size: number of tokens (embedding rows and logit dimension).
        block_size: number of context tokens fed to the model.
        emb_size: embedding dimension per token.
        width: hidden layer width.
        num_linear: number of extra hidden Linear layers.
        activation: key into `activation_fns`.
        norm: None (no normalization), "layernorm", or "batchnorm".
            Any other truthy name raises KeyError.

    Returns:
        An nn.Sequential model.
    """
    norm_layer = {
        "layernorm": nn.LayerNorm,
        "batchnorm": nn.BatchNorm1d
    }[norm] if norm else None
    act_fn = activation_fns[activation]

    def maybe_norm(dim: int) -> list:
        # Optional layers as data rather than a side-effecting conditional expression.
        return [norm_layer(dim)] if norm else []

    # Module construction order is unchanged, so default-init RNG use is identical.
    layers = [
        nn.Embedding(vocab_size, emb_size),
        nn.Flatten(),
        nn.Linear(block_size * emb_size, width),
        *maybe_norm(width),
        act_fn(),
    ]
    for _ in range(num_linear):
        layers += [nn.Linear(width, width), *maybe_norm(width), act_fn()]
    layers += [nn.Linear(width, vocab_size), *maybe_norm(vocab_size)]
    return nn.Sequential(*layers)
Hook for calculating singular values:
class SingularValuesRatio(HookHandler):
    """Forward hook recording σ₁ / Σᵢ σᵢ of each hooked module's output, once per call.

    Values close to 1 mean the output is dominated by a single direction (rank
    collapse). Assumes `output` is a 2-D (batch, features) tensor, as for the
    nn.Linear layers hooked here; confirm before hooking other layer types.
    """

    def hook_fn(self, module, input, output):
        history = self.records.setdefault(module, [])
        # Only singular values are needed. svdvals skips computing U and V, and
        # no_grad keeps this diagnostic out of the training autograd graph.
        with torch.no_grad():
            s = torch.linalg.svdvals(output)  # sorted in descending order
        history.append((s[0] / s.sum()).item())
Here we are careful not to use large LR. Dying units causes rank collapse which may lead us to incorrectly conclude that LN mitigates rank collapse better than an unnormalized network (since LN prevents dying units).
# Note: act = relu. But same results for tanh and gelu.
# Train with BN, LN and no normalization, recording σ₁ / Σσ of every Linear output per step.
sv_hooks = {}
message = ["\nloss (last 100 steps)"]
# These mutate the shared parameter dicts in place, so later cells see the new values.
model_params["num_linear"] = 6
train_params["lr"] = 0.3
train_params["steps"] = 3000
for norm in ["batchnorm", "layernorm", None]:
    init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
    model = init_model(get_model(norm=norm, **model_params), init_fn)
    sv_ratio = SingularValuesRatio()
    hooks = [{"bwd_hooks": [], "fwd_hooks": [sv_ratio], "layer_types": (nn.Linear,)}]
    trainer = run(model, train_dataset, **train_params, hooks=hooks)
    sv_hooks[norm] = sv_ratio
    message.append(f" {norm}: \t{sum(trainer.logs['train']['loss'][-100:]) / 100:.3f}")
print("\n".join(message))
loss (last 100 steps)
batchnorm: 2.466
layernorm: 2.492
None: 2.499
Show code cell source
# One panel per normalization: smoothed σ₁ / Σσ for each Linear layer over training
# (later layers drawn more opaque).
fig, ax = plt.subplots(1, len(sv_hooks), figsize=(10, 4))
for i, norm in enumerate(sv_hooks):
    records = sv_hooks[norm].records
    for j, m in enumerate(records.keys()):
        ax[i].plot(moving_avg(records[m], window=0.0005), label=m.__class__.__name__ + "." + str(j), color=f"C{j}", alpha=min(0.5 + (0.5 / len(records.keys())) * j, 1.0))
    ax[i].set_ylim(0, 1)
    ax[i].grid(linestyle="dashed", alpha=0.5)
    ax[i].set_xlabel("steps")
    ax[i].set_title(norm)
ax[0].set_ylabel(r"${\sigma}_1\; /\; \mathbf{\sigma}$.sum()")
ax[-1].set_title("Unnormalized")  # last key is None; give that panel a readable title
ax[-1].legend(fontsize=7)
fig.tight_layout()
Remark. The degree of rank collapse increases with depth. This makes sense for linear networks, since \(\text{rank}\,\boldsymbol{\mathsf{A}}\boldsymbol{\mathsf{X}} \leq \text{rank}\,\boldsymbol{\mathsf{X}}\) (see eqn. (8) in [FZH+22] for general networks). LN does not fully prevent rank collapse. In the next notebook, we will look at residual connections, which help to mitigate rank collapse without having to use BN.
Appendix: Bias initialization#
Recall that in our experiments, biases are initialized to zero. This makes sense for hidden layers. However, the final-layer biases should be initialized carefully. For example, given an imbalanced dataset with binary label ratio 1:9, it makes sense to calibrate the bias of your logits so that your network predicts a probability of \(0.1\) for the minority class at initialization. Setting the logit biases correctly can speed up convergence, since otherwise the first few training steps are spent just learning the bias.
This technique is demonstrated below. The following dataset has a 1:6 class imbalance:
import torch

torch.manual_seed(2)

N = 1500  # inner circle (label 0, minority class)
M = 6 * N  # outer circle (label 1) -> 1:6 class imbalance


def noise(n, e):
    """Return an (n, 2) tensor of isotropic Gaussian noise with std ``e``."""
    return torch.randn(n, 2) * e


# Angles are sampled first, then the noise for x0, then the noise for x1;
# this RNG call order determines the exact points drawn for the seed above.
s = 2 * torch.pi * torch.rand(N, 1)
t = 2 * torch.pi * torch.rand(M, 1)
x0 = torch.cat([0.1 * torch.cos(s), 0.1 * torch.sin(s)], dim=1) + noise(N, 0.05)
x1 = torch.cat([1.0 * torch.cos(t), 1.0 * torch.sin(t)], dim=1) + noise(M, 0.1)
y0 = torch.zeros(N, dtype=torch.long)
y1 = torch.ones(M, dtype=torch.long)
Show code cell source
# Scatter both classes; the class index doubles as the legend label.
for label, (points, color) in enumerate([(x0, "C0"), (x1, "C1")]):
    plt.scatter(points[:, 0], points[:, 1], s=2.0, label=label, color=color)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.legend()
plt.axis('equal');
x = torch.cat([x0, x1])
y = torch.cat([y0, y1])
ds = torch.utils.data.TensorDataset(x, y)


def train_loader():
    """Return a new shuffled DataLoader over ``ds`` with batch size 64."""
    return DataLoader(ds, batch_size=64, shuffle=True)


y.float().mean()  # fraction of label-1 examples: 6 / 7 = 0.8571428571428571
tensor(0.8571)
Training our baseline model:
def model():
    """Build a fresh 2 -> 3 -> 3 -> 2 ReLU MLP.

    Returns:
        An ``nn.Sequential`` mapping inputs with last dimension 2 to two
        class logits. No custom initialization is applied, so every
        ``nn.Linear`` keeps PyTorch's default initialization.
    """
    return nn.Sequential(
        nn.Linear(2, 3), nn.ReLU(),
        nn.Linear(3, 3), nn.ReLU(),
        nn.Linear(3, 2),
    )
# Baseline run: output layer keeps the default initialization from model().
RANDOM_SEED = 0
set_seed(RANDOM_SEED)  # the test model is built after resetting this same seed, so both start from the same random weights
model_base = model()
optim = torch.optim.Adam(model_base.parameters(), lr=0.1)
losses_base = []  # training loss recorded after every optimizer step
for epoch in range(2):
    for xb, yb in train_loader():
        logits = model_base(xb)
        loss = F.cross_entropy(logits, yb)
        loss.backward()
        optim.step()
        optim.zero_grad()  # clear gradients so they do not accumulate into the next step
        losses_base.append(loss.item())
Next, initialize our test model from the same seed, but set the biases of the output (logit) layer so that its initial predictions match the class frequencies:
set_seed(RANDOM_SEED)
model_test = model()

# Calibrate the output-layer biases to the class frequencies, so the untrained
# network predicts roughly p(y = c) = count_c / len(y). Softmax is
# shift-invariant, so only bias differences matter; anchoring class 0 at 0
# gives b = [0, log(9000 / 1500)] = [0, log 6] for this dataset.
output_layer = model_test[-1]
class_counts = torch.bincount(y).float()
with torch.no_grad():
    output_layer.bias.copy_(torch.log(class_counts / class_counts[0]))

# Sanity check on the full dataset rather than a leftover mini-batch:
# the mean predicted class probabilities should be close to [1/7, 6/7].
F.softmax(model_test(x), dim=1).mean(dim=0)
tensor([0.1388, 0.8612], grad_fn=<MeanBackward1>)
Training the test model:
# Test run: identical optimizer, schedule and loss as the baseline; only the
# output-layer bias initialization differs.
optim = torch.optim.Adam(model_test.parameters(), lr=0.1)
losses_test = []  # training loss recorded after every optimizer step
for epoch in range(2):
    for xb, yb in train_loader():
        logits = model_test(xb)
        loss = F.cross_entropy(logits, yb)
        loss.backward()
        optim.step()
        optim.zero_grad()  # clear gradients so they do not accumulate into the next step
        losses_test.append(loss.item())
Results. The test model converges faster, i.e. it has a lower loss both at the start and at the end of training. This is consistent over multiple random seeds (3 or 4).
Show code cell source
# Per-step training loss of both runs, followed by their first/last values.
for name, curve in (("base", losses_base), ("test", losses_test)):
    plt.plot(curve, label=name)
plt.legend();
print(f"base: [{losses_base[0]:.3e}, {losses_base[-1]:.3e}]")
print(f"test: [{losses_test[0]:.3e}, {losses_test[-1]:.3e}]")
base: [6.208e-01, 4.768e-07]
test: [3.772e-01, 0.000e+00]
■