Activations and Gradients#



Readings: [BKH16] [IS15]

Introduction#

Training neural networks involves computation across millions (or billions) of weights and activations. This process can be fragile. In this notebook, we attach hooks to deep neural nets to analyze the statistics of activations and gradients during training, and consider pitfalls when these are improperly scaled. Finally, we introduce layer normalization (LN) [BKH16], which allows stable propagation of activations and gradients across layers. This makes training deep networks considerably easier (e.g. without a lot of hyperparameter tuning).

In the appendix, we consider the effects of the choice of activation function on training dynamics. We also discuss rank collapse for neural network layers, where the dimensionality of the output space of each layer degenerates with depth. We find empirically that batch normalization [IS15] prevents rank collapse in MLPs, while LN does not. This is consistent with [DKB+20] and [DCL21].

Preliminaries#

import math
import torch
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib_inline import backend_inline

DATASET_DIR = Path("./data").absolute()
RANDOM_SEED = 1
DEBUG = False
MATPLOTLIB_FORMAT = "png" if DEBUG else "svg"

def set_seed(s: int):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)

set_seed(RANDOM_SEED)
warnings.simplefilter(action="ignore")
backend_inline.set_matplotlib_formats(MATPLOTLIB_FORMAT)

Still using the names dataset from the previous notebook:

import os
if not os.path.isfile("./data/surnames_freq_ge_100.csv"):
    !wget -O ./data/surnames_freq_ge_100.csv https://raw.githubusercontent.com/particle1331/spanish-names-surnames/master/surnames_freq_ge_100.csv
    !wget -O ./data/surnames_freq_ge_20_le_99.csv https://raw.githubusercontent.com/particle1331/spanish-names-surnames/master/surnames_freq_ge_20_le_99.csv
else:
    print("Data files already exist.")

col = ["surname", "frequency_first", "frequency_second", "frequency_both"]
df1 = pd.read_csv(DATASET_DIR / "surnames_freq_ge_100.csv", names=col, header=0)
df2 = pd.read_csv(DATASET_DIR / "surnames_freq_ge_20_le_99.csv", names=col, header=0)
Data files already exist.
FRAC_LIMIT = 0.10 if DEBUG else 1.0
df = pd.concat([df1, df2], axis=0)[['surname']].sample(frac=FRAC_LIMIT)
df['surname'] = df['surname'].map(lambda s: s.lower())
df['surname'] = df['surname'].map(lambda s: s.replace("de la", "dela"))
df['surname'] = df['surname'].map(lambda s: s.replace(" ", "_"))

names = [n for n in df.surname.tolist() if "'" not in n and 'ç' not in n and len(n) >= 2]
df = df[['surname']].dropna().astype(str)
df = df[df.surname.isin(names)]
df.to_csv(DATASET_DIR / 'spanish_surnames.csv', index=False)
df = pd.read_csv(DATASET_DIR / 'spanish_surnames.csv').dropna()
df.head()
surname
0 agredano
1 poblador
2 girba
3 rabanales
4 yucra

Datasets#

Defining here the dataset class used in the previous notebook:

import torch
from torch.utils.data import Dataset

class CharDataset(Dataset):
    def __init__(self, contexts: list[str], targets: list[str], chars: str):
        self.chars = chars
        self.ys = targets
        self.xs = contexts
        self.block_size = len(contexts[0])
        self.itos = {i: c for i, c in enumerate(self.chars)}
        self.stoi = {c: i for i, c in self.itos.items()}

    def get_vocab_size(self):
        return len(self.chars)
    
    def __len__(self):
        return len(self.xs)
    
    def __getitem__(self, idx):
        x = self.encode(self.xs[idx])
        y = torch.tensor(self.stoi[self.ys[idx]]).long()
        return x, y 
        
    def decode(self, x: torch.tensor) -> str:
        return "".join([self.itos[c.item()] for c in x])

    def encode(self, word: str) -> torch.tensor:
        return torch.tensor([self.stoi[c] for c in word]).long()


def build_dataset(names, block_size=3):
    """Build word context -> next char target lists from names."""
    xs = []     # context list
    ys = []     # target list
    for name in names:
        context = ["."] * block_size
        for c in name + ".":
            xs.append(context)
            ys.append(c)
            context = context[1:] + [c]
    
    chars = sorted(list(set("".join(ys))))
    return CharDataset(contexts=xs, targets=ys, chars=chars)

Example dataset with block size 3:

dataset = build_dataset(names, block_size=3)
xs = []
ys = []
for i in range(7):
    x, y = dataset[i]
    xs.append(x)
    ys.append(y)

pd.DataFrame({'x': [x.tolist() for x in xs], 'y': [y.item() for y in ys], 'x_word': ["".join(dataset.decode(x)) for x in xs], 'y_char': [dataset.itos[c.item()] for c in ys]})
x y x_word y_char
0 [0, 0, 0] 2 ... a
1 [0, 0, 2] 8 ..a g
2 [0, 2, 8] 19 .ag r
3 [2, 8, 19] 6 agr e
4 [8, 19, 6] 5 gre d
5 [19, 6, 5] 2 red a
6 [6, 5, 2] 15 eda n

Ideally, we should use a stratified k-fold split that ensures the character distribution is the same between splits. But that is more effort than we need here. Instead, we just partition the names dataset by index to create the train and validation datasets.

SPLIT_RATIO = 0.30
split_point = int(SPLIT_RATIO * len(names))
names_train = names[:split_point]
names_valid = names[split_point:]
train_dataset = build_dataset(names_train, block_size=3)
valid_dataset = build_dataset(names_valid, block_size=3)

len(train_dataset), len(valid_dataset)
(186339, 433815)

Weight initialization#

SGD requires choosing an arbitrary starting point \(\boldsymbol{\Theta}_{\text{init}}.\) Setting all weights to zero or some constant does not work, as symmetry between the neurons of the network will make it difficult (if not impossible) to train the model. Hence, setting the weights randomly to break symmetry is a good starting point. However, this is still not enough, since the variance of each output unit grows additively with fan-in (assuming independent, zero-mean inputs and weights):

\[{\sigma_{\boldsymbol{\mathsf{y}}}} = \sqrt{n} \cdot {\sigma_{\boldsymbol{\mathsf{w}}}} \, {\sigma_{\boldsymbol{\mathsf{x}}}}\]

where \(n = |\boldsymbol{\mathsf{x}}|.\)
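This follows from the variance of a sum of independent terms. Assuming the entries of \(\boldsymbol{\mathsf{w}}\) and \(\boldsymbol{\mathsf{x}}\) are independent with zero mean, for a single output unit we have

\[\sigma_{\boldsymbol{\mathsf{y}}}^2 = \text{var}\left(\sum_{i=1}^{n} {\mathsf{w}}_{i} {\mathsf{x}}_{i}\right) = \sum_{i=1}^{n} \text{var}({\mathsf{w}}_{i})\, \text{var}({\mathsf{x}}_{i}) = n\, \sigma_{\boldsymbol{\mathsf{w}}}^2\, \sigma_{\boldsymbol{\mathsf{x}}}^2.\]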

import seaborn as sns

x = torch.randn(1000, 100)
w = torch.randn( 100, 200)
y = x @ w
plt.figure(figsize=(5, 3))
sns.distplot(x.reshape(-1), color="C0", label='$\mathbf{x}$ (input)');
sns.distplot(y.reshape(-1), color="C1", label='$\mathbf{y}$ (output)');
plt.legend();
../../_images/c166477ff9142d126aed8468136fe23b5b4b5f0c0b842bd4b8e640ee469154d6.svg

Observe that the output \(\boldsymbol{\mathsf{y}} = \boldsymbol{\mathsf{x}}^\top \boldsymbol{\mathsf{w}}\) for normally distributed \(\boldsymbol{\mathsf{x}}\) and \(\boldsymbol{\mathsf{w}}\) starts to spread out. This makes sense since we are adding \(n\) terms where \(n = |\boldsymbol{\mathsf{x}}|.\) Ideally, we want \(\boldsymbol{\mathsf{y}}\) to have the same standard deviation as \(\boldsymbol{\mathsf{x}}\) to maintain stable activation flow. Otherwise, activations will recursively grow at each layer. This becomes increasingly problematic with depth.


Dead neurons. Activation functions have regions where their gradients saturate (i.e. become vanishingly small), so it is very important to control the range of preactivations. A dead neuron has vanishingly small gradients for most training examples. This means that the weights of this neuron will change very slowly compared to other neurons.

A neuron can be dead at initialization or die in the course of training. For instance, a high learning rate can result in large weights that saturate the neurons similar to the situation above. For dense layers, dead neurons tend to remain dead for a long time since weight gradients are too small to significantly change the existing weights for that neuron.
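As a minimal sketch (the layer and batch here are purely illustrative, not part of the experiments below): a ReLU unit whose preactivation is negative for every input receives exactly zero gradient on its incoming weights, so SGD cannot revive it.

import torch.nn as nn

lin = nn.Linear(10, 3)
with torch.no_grad():
    lin.bias[0] = -100.0                  # unit 0 is dead: its preactivation is always negative
x_batch = torch.randn(64, 10)
torch.relu(lin(x_batch)).sum().backward()
print(lin.weight.grad[0].abs().max())     # exactly zero -> weights of unit 0 never move
print(lin.weight.grad[1].abs().max())     # nonzero for a healthy unit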

activations = {
    "tanh": torch.tanh,
    "relu": torch.relu,
}

def d(f, x):
    """Compute derivative of activation with respect to x."""
    y = f(x)
    y.backward(torch.ones_like(x))
    return x.grad

def plot_activation(act: str, ax, x):
    x.grad = None
    f = activations[act]
    y = f(x)
    df = d(f, x)
    x, y = x.detach().numpy(), y.detach().numpy()
    
    # Plotting
    ax.plot(x, y,  linewidth=2, color='red',   label="f(x)")
    ax.plot(x, df, linewidth=2, color='black', label="f'(x)")
    ax.set_title(act)
    ax.grid(linestyle="dashed")
    ax.set_ylim(-1.5, 4)
    if act == "tanh":
        ax.axvspan(-5, -1.8183, -10, 10, color='lightgray')
        ax.axvspan(1.8183, 5,   -10, 10, color='lightgray')
    elif act == "relu":
        ax.axvspan(-5, 0, -10, 10, color='lightgray', label="|f'(x)| ≤ 0.1")
        ax.legend(loc='lower right')
    

# Plotting
fig, ax = plt.subplots(1, 2, figsize=(10, 3.5))

x = torch.linspace(-5, 5, 1000, requires_grad=True)
for i, act_name in enumerate(activations.keys()):
    plot_activation(act_name, ax[i], x)

fig.tight_layout()
../../_images/3117a7c62c86ccd31250b107e6b3bd1b4b75039e060ecd5f102517d42c1d1d8f.svg

Figure. Plotting activations and activation gradients. Saturation regions, where the magnitude of the derivative falls below a threshold (here 0.1), are highlighted gray.


Xavier init. [GB10] Since \(\sigma_{\boldsymbol{\mathsf{y}}} = \sqrt{n} \cdot \sigma_{\boldsymbol{\mathsf{w}}} \, \sigma_{\boldsymbol{\mathsf{x}}}\), one straightforward fix is to initialize the weights \(\boldsymbol{\mathsf{w}}\) with a distribution having \({\sigma_{\boldsymbol{\mathsf{w}}}}=\frac{1}{\sqrt{n}}\) where \(n = |\boldsymbol{\mathsf{x}}|.\) Moreover, we set biases to zero. Note that setting the standard deviation this way is equivalent to just scaling the random variable by \(\frac{1}{\sqrt{n}}\), since the standard deviation scales linearly with a constant factor:

x = torch.randn(1000, 100)
w = torch.randn( 100, 200)  # Xavier normal. For xavier uniform, set w ~ U[-a, a] with a = (3 / n) ** 0.5
y = x @ (w / np.sqrt(100))
plt.figure(figsize=(5, 3))
sns.distplot(x.reshape(-1), color="C0", label='$\mathbf{x}$ (input)')
sns.distplot(y.reshape(-1), color="C1", label='$\mathbf{y}$ (output)')
plt.ylim(0, 0.6)
plt.legend();
../../_images/00dca7d401ff660b71c34a399b7a1753be0e0f7a5c934820a5f27cb0e7ae0603.svg

Remark. Proper weight init can also help with the variance of gradients. For dense layers, the backprop equations that relate input and output gradients are linear with \(\boldsymbol{\mathsf{w}}^\top.\) Improper scaling of weights can therefore cause vanishing or exploding gradients for MLPs, as recursively multiplying weight matrices can shrink or grow gradients exponentially with depth. Hence, we can alternatively sample weights with \(\sigma_{\boldsymbol{\mathsf{w}}} = \frac{1}{\sqrt{n_\text{out}}}\) (fan-out mode). Indeed, weight init is often implemented with fan average mode: \(\sigma_{\boldsymbol{\mathsf{w}}} = \sqrt{{2}/{{(n_\text{in} + n_\text{out})}}}.\)
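As a quick check of this backward analogue (a sketch mirroring the forward experiment above): gradients flowing back through a linear layer spread out with fan-out unless the weights are scaled by \(1/\sqrt{n_\text{out}}\).

n_in, n_out = 100, 200
x = torch.randn(1000, n_in, requires_grad=True)
w = torch.randn(n_in, n_out)
y = x @ w
y.backward(torch.randn_like(y))     # pretend upstream gradient ~ N(0, 1)
print(x.grad.std())                 # ≈ sqrt(n_out) ≈ 14.1: gradients spread out

x.grad = None
y = x @ (w / n_out ** 0.5)          # fan-out scaling
y.backward(torch.randn_like(y))
print(x.grad.std())                 # ≈ 1: gradient scale is preserved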


Gain. Note that this scale factor only holds for linear layers. Nonlinear activations squash their inputs, and the effect compounds as we stack layers in deep networks. The factor \(\mathsf{g}\) such that \({\sigma_{\boldsymbol{\mathsf{y}}}} = \mathsf{g} \cdot {\sigma_{\boldsymbol{\mathsf{x}}}}\) is called the gain. This can be introduced as a factor on the standard deviation of the weights, i.e. setting the parameters of the distribution such that \(\sigma_{\boldsymbol{\mathsf{w}}} = \mathsf{g} \frac{1}{\sqrt{n}}\) where \(n = |\boldsymbol{\mathsf{x}}|\) for some \(\mathsf{g} > 0.\) Typically, weights are sampled from either normal or uniform distributions.

fig, ax = plt.subplots(1, 2, figsize=(8, 3))

n = 100
x = torch.randn(1000, n)
z = torch.tanh(x)   # Note: input of a hidden layer is an activation
w = torch.randn(n, 200) / n ** 0.5
y = z @ w

sns.distplot(x.reshape(-1), ax=ax[0], color="C0", label='$\mathbf{x}$ (input)');
sns.distplot(y.reshape(-1), ax=ax[0], color="C1", label='$\mathbf{y}$ (output)');
ax[0].set_title(r"$w \sim N(0, \sigma^2),$ $\sigma = n ^{-0.5}$", size=12)
ax[0].set_ylim(0, 0.8)

g = 5 / 3   # tanh gain
x = torch.randn(1000, n)
z = torch.tanh(x)
w = torch.randn(n, 200) * g / n ** 0.5
y = z @ w

sns.distplot(x.reshape(-1), ax=ax[1], color="C0", label='$\mathbf{x}$ (input)');
sns.distplot(y.reshape(-1), ax=ax[1], color="C1", label='$\mathbf{y}$ (output)');

ax[1].set_title(r"$w \sim N(0, \sigma^2),$ $\sigma = \frac{5}{3} {n}^{-0.5}$", size=12)
ax[1].set_ylim(0, 0.8)
ax[1].legend()
fig.tight_layout();
../../_images/8a04a9d0e811af61baddfb684743ab661fcb503c7c2f271578b02ce44f5957dc.svg

Kaiming init. Note that \(\mathsf{g}\) for an activation is usually obtained using some heuristic or by performing empirical tests, e.g. [HZRS15b] sets \(\mathsf{g} = \sqrt{2}\) since half of ReLU outputs are zero in the calculation of variance (see proof). See also calculate_gain for other commonly used activations.

The PyTorch default for nn.Linear is essentially \(\mathsf{g} = \frac{1}{\sqrt{3}}\) with Xavier uniform fan-in initialization. Note that the standard deviation of the uniform distribution \(U[-a, a]\) is determined by the bound \(a\), i.e. \(\sigma = \frac{a}{\sqrt{3}}.\) The PyTorch implementation sets \(a = \frac{1}{\sqrt{n}}\), giving an effective gain of \(\frac{1}{\sqrt{3}}.\)

import torch.nn as nn

fan_in = 100
lin = nn.Linear(fan_in, 1234)
w = lin.weight.data
w.std(), 1 / (np.sqrt(3) * np.sqrt(fan_in))
(tensor(0.0576), 0.05773502691896258)

Remark. The formula for \(\sigma\) of \(U[-a, a]\) also explains the form of nn.init.kaiming_uniform_.
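For example (a quick check, reusing fan_in from the cell above): with nonlinearity="relu", kaiming_uniform_ samples from \(U[-a, a]\) with \(a = \sqrt{2}\sqrt{3/n}\), so the resulting standard deviation should be close to \(\sqrt{2/n}\):

w = torch.empty(1234, fan_in)
nn.init.kaiming_uniform_(w, nonlinearity="relu")
w.std().item(), math.sqrt(2 / fan_in)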


Logits. Note that the Xavier init heuristic also applies to the weights of the logits layer. It essentially acts as softmax temperature:

import torch.nn.functional as F

x = torch.randn(1, 30)
w = torch.randn(30, 10)
y0 = F.softmax(x @ w, dim=-1)
y1 = F.softmax(x @ (w / np.sqrt(30)), dim=-1)
fig, ax = plt.subplots(1, 2, figsize=(8, 3))
ax[0].set_title(r"$\sigma = 1$", size=11)
ax[1].set_title(r"$\sigma = {1}/{\sqrt{30}}$", size=11)
ax[0].set_xlabel("class index")
ax[1].set_xlabel("class index")
ax[0].bar(range(10), y0[0])
ax[1].bar(range(10), y1[0]);
../../_images/fb170aa69f514c0ae39639ce5585c519510e8868f8aa016f25de474d0b7607e0.svg

Training with hooks#

In our experiments, we use total steps per run instead of epochs. The following class wraps a data loader so that the iterator resets whenever the number of steps exceeds one epoch (i.e. whenever a single pass over the entire training dataset has been completed):

class InfiniteDataLoader:
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iterator = iter(self.data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.data_iterator)
        except StopIteration:
            self.data_iterator = iter(self.data_loader)
            batch = next(self.data_iterator)
        return batch

The Trainer class below streamlines model training (see the previous notebook). Note that the class is augmented with hooks which we will use to obtain activation and gradient statistics during the forward and backward passes. The hooks are attached to the model before training, then removed after. We skip validation during training since our current concern is training dynamics (not generalization).

from tqdm.notebook import tqdm
from contextlib import contextmanager
from torch.utils.data import DataLoader

DEVICE = "mps"


@contextmanager
def eval_context(model):
    """Temporarily set to eval mode inside context."""
    state = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(state)


class Trainer:
    def __init__(self,
                 model, optim, loss_fn, 
                 scheduler=None,
                 callbacks=[],
                 device=DEVICE, verbose=True):
        
        self.model = model
        self.optim = optim
        self.device = device
        self.loss_fn = loss_fn
        self.logs = {"train": {"loss": []}}
        self.verbose = verbose
        self.scheduler = scheduler
        self.callbacks = callbacks
        self.forward_hooks = []
        self.backward_hooks = []
    
    def __call__(self, x):
        return self.model(x.to(self.device))

    def forward(self, batch):
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        return self.model(x), y

    def train_step(self, batch):
        preds, y = self.forward(batch)
        loss = self.loss_fn(preds, y)
        loss.backward()
        self.optim.step()
        self.optim.zero_grad()
        return loss

    @torch.inference_mode()
    def valid_step(self, batch):
        preds, y = self.forward(batch)
        loss = self.loss_fn(preds, y, reduction="sum")
        return loss

    def register_hooks(self,            # !!!
                       fwd_hooks=[], 
                       bwd_hooks=[], 
                       layer_types=(nn.Linear, nn.Conv2d)):
        
        for _, module in self.model.named_modules():
            if isinstance(module, layer_types):
                for f in fwd_hooks:
                    h = module.register_forward_hook(f)
                    self.forward_hooks.append(h)
                
                for g in bwd_hooks:
                    h = module.register_backward_hook(g)
                    self.backward_hooks.append(h)

    def remove_hooks(self):
        for h in self.forward_hooks + self.backward_hooks:
            h.remove()
    
    def run(self, steps, train_loader):
        train_loader = InfiniteDataLoader(train_loader)
        for _ in tqdm(range(steps)):
            # optim and lr step
            batch = next(train_loader)
            loss = self.train_step(batch)
            if self.scheduler:
                self.scheduler.step()

            # step callbacks
            for callback in self.callbacks:
                callback()

            # logs @ train step
            self.logs["train"]["loss"].append(loss.item())

    def evaluate(self, data_loader):
        with eval_context(self.model):
            valid_loss = 0.0
            for batch in data_loader:
                loss = self.valid_step(batch)
                valid_loss += loss.item()

        return {"loss": valid_loss / len(data_loader.dataset)}

Forward hooks are expected to look like:

hook(module, input, output) -> None or modified output

On the other hand, backward hook functions should look like:

hook(module, grad_in, grad_out) -> None or modified grad_in
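As a minimal standalone illustration (the layer here is just for the example), a forward hook that prints the output shape:

lin = nn.Linear(4, 2)
handle = lin.register_forward_hook(lambda module, input, output: print(output.shape))
lin(torch.randn(3, 4))    # prints torch.Size([3, 2])
handle.remove()           # detach the hook when done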

The following classes are for handling outputs of the hook functions:

class HookHandler:
    def __init__(self):
        self.records = {}

    def hook_fn(self, module, input, output):
        raise NotImplementedError


class OutputStats(HookHandler):
    def hook_fn(self, module, input, output):
        self.records[module] = self.records.get(module, [])
        self.records[module].append(output)


class WeightGradientStats(HookHandler):
    def hook_fn(self, module, grad_in, grad_out):
        self.records[module] = self.records.get(module, [])
        for t in grad_in:
            if t is not None and t.shape == module.weight.data.shape:
                self.records[module].append(t)


class ActivationGradientStats(HookHandler):
    def hook_fn(self, module, grad_in, grad_out):
        self.records[module] = self.records.get(module, [])
        for t in grad_in:
            if t is not None:
                self.records[module].append(grad_in[0])


class DyingReluStats(HookHandler):
    def __init__(self, sat_threshold=1e-8, frac_threshold=0.95):
        super().__init__()
        self.sat_threshold = sat_threshold
        self.frac_threshold = frac_threshold

    def hook_fn(self, module, input, output):
        self.records[module] = self.records.get(module, [])
        B = output.shape[0]
        count_dead = ((output < self.sat_threshold).float().sum(dim=0) / B) > self.frac_threshold
        frac_dead = count_dead.float().mean()
        self.records[module].append(frac_dead.item())

Setting up our experiment harness:

def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int):

    layers = [
        nn.Embedding(vocab_size, emb_size), 
        nn.Flatten(),
        nn.Linear(block_size * emb_size, width), 
        nn.ReLU()
    ] 
    
    for _ in range(num_linear):
        layers.append(nn.Linear(width, width))
        layers.append(nn.ReLU())

    layers.append(nn.Linear(width, vocab_size))
    model = nn.Sequential(*layers)
    return model
    

def init_model(model, fn):
    """Expected `fn` modifies linear layer weights in-place. See 
    https://pytorch.org/docs/stable/nn.init.html for init functions."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            fn(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        if isinstance(m, nn.Embedding):
            g = torch.Generator().manual_seed(RANDOM_SEED)
            nn.init.normal_(m.weight, generator=g)
    return model


def run(model,
        train_dataset: Dataset,
        batch_size: int,
        hooks: list[dict], 
        steps: int, 
        lr: float) -> Trainer:

    # Setup optimization and data loader
    set_seed(RANDOM_SEED)
    loss_fn = F.cross_entropy
    model = model.to(DEVICE)
    optim = torch.optim.SGD(model.parameters(), lr=lr)
    trainer = Trainer(model, optim, loss_fn, device=DEVICE)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, shuffle=True)

    # Finally, training...
    try:
        for hook_dict in hooks:
            trainer.register_hooks(
                fwd_hooks=[h.hook_fn for h in hook_dict["fwd_hooks"]], 
                bwd_hooks=[h.hook_fn for h in hook_dict["bwd_hooks"]], 
                layer_types=hook_dict["layer_types"]
            )

        trainer.run(steps=steps, train_loader=train_loader)

    except Exception as e:
        print(e)
    finally:
        trainer.remove_hooks()

    return trainer

The task and model used here are from the previous notebook. In the init_model function, we iterate over the layers and apply a function from the nn.init library to modify layer weights in-place. The xavier_uniform_ implementation uses fan average mode. This is applied to linear layer weights, while biases are set to zero. Finally, the embedding layer is initialized with weights from a Gaussian distribution.

def init_hooks() -> list[dict]:
    global dead_relu_stats
    global relu_outs_stats
    global relu_grad_stats
    global linear_outs_stats
    global linear_grad_stats

    dead_relu_stats = DyingReluStats(sat_threshold=1e-8, frac_threshold=0.95)
    relu_outs_stats = OutputStats()
    relu_grad_stats = ActivationGradientStats()
    linear_outs_stats = OutputStats()
    linear_grad_stats = WeightGradientStats()
    
    return [
        {
            "fwd_hooks": [linear_outs_stats], 
            "bwd_hooks": [linear_grad_stats], 
            "layer_types": (nn.Linear,)
        },
        {
            "fwd_hooks": [dead_relu_stats, relu_outs_stats], 
            "bwd_hooks": [relu_grad_stats], 
            "layer_types": (nn.ReLU,)
        }
    ]
    

model_params = {
    "vocab_size": 29,
    "block_size": 3,
    "emb_size": 10,
    "width": 30,
    "num_linear": 2
}

train_params = {
    "steps": 1000,
    "batch_size": 32,
    "lr": 0.1
}

g = lambda: torch.Generator().manual_seed(RANDOM_SEED)
init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
model = init_model(get_model(**model_params), init_fn)
hooks = init_hooks()
trainer = run(model, train_dataset, **train_params, hooks=hooks)
def moving_avg(a, window=0.05):
    ma = []
    steps = len(a)
    w = int(window * steps)
    for s in range(steps):
        ma.append(np.mean(a[-w + s + 1: s + 1]))
    return ma

def plot_training_loss(trainer, window=0.05, figsize=(5, 3)):
    train_loss = trainer.logs["train"]["loss"]
    train_loss_ma = moving_avg(train_loss, window)

    plt.figure(figsize=figsize)
    plt.plot(trainer.logs["train"]["loss"], alpha=0.5, label="train")
    plt.plot(train_loss_ma, color="C0", label="train (MA)")
    plt.ylabel("loss")
    plt.xlabel("steps")
    plt.ylim(min(trainer.logs["train"]["loss"]) - 0.2, min(max(trainer.logs["train"]["loss"]) + 0.2, 5.0))
    plt.grid(linestyle="dotted", alpha=0.6)
    plt.legend();
plot_training_loss(trainer)
../../_images/e95e8bf609d0b320cefca0fe692e09f4fbe0d743cdae2feedc51ffdd084f8698.svg

Histograms#

The hooks stored activation and gradient statistics during training which we now visualize:

def plot_training_histograms(stats: list[HookHandler], 
                             bins: int = 100, 
                             num_stds: int = 4, 
                             cmaps: list[str] = None, 
                             labels: list[str] = None,
                             figsize=(6, 2), 
                             aspect: int = 5):
    
    rows = len(stats)
    cols = len(stats[0].records.keys())
    cmaps = ["viridis"] * rows if cmaps is None else cmaps
    H, W = figsize
    fig, ax = plt.subplots(rows, cols, figsize=(H * rows, W * cols))

    # estimate a good range per row (i.e. per handler)
    # same range => show vanishing / exploding over layers
    h = {}
    acts = {}
    for i in range(rows):
        for j, m in enumerate(stats[0].records.keys()):
            if m not in stats[i].records:
                continue
            acts[i, j] = torch.stack([t.reshape(-1) for t in stats[i].records[m]], dim=0).cpu()
        
        all_data = torch.concat([acts[i, j].reshape(-1) for j in range(len(stats[0].records))])
        h[i] = all_data.std().item() * num_stds

    # calculate histograms
    for i in range(rows):
        for j, m in enumerate(stats[0].records.keys()):    
            if m not in stats[i].records:
                continue

            with torch.no_grad():
                # Add one count -> log = more colorful, note: log 1 = 0.
                # total count (i.e. sum per col) == batch_size * num_neurons, .: normalized
                hists = [(torch.histc(acts[i, j][t], bins=bins, min=-h[i], max=h[i]) + 1).log().cpu().numpy().reshape(1, -1) for t in range(acts[i, j].shape[0])]

            # histogram image
            axf = lambda i, j: ax[j] if rows == 1 else ax[i, j]
            axf(i, j).set_title(m.__class__.__name__ + "." + str(j), size=10)
            axf(i, j).imshow(np.flip(np.concatenate(hists, axis=0).T, axis=0), aspect='auto', cmap=cmaps[i])     # (!)
            axf(i, j).set_aspect(aspect)                                                                 # transpose => positive vals = down 
            axf(i, j).set_yticks([])                                                                     # flip => positive vals = up
            if j == 0 and labels is not None:
                axf(i, j).set_ylabel(f"{labels[i]}\n[-{h[i]:.1e}, {h[i]:.1e}]")
            else:                                                           
                axf(i, j).set_ylabel(f"[-{h[i]:.1e}, {h[i]:.1e}]")
            axf(i, j).set_xlabel("steps")
    
    fig.tight_layout()
    plt.show()
def plot_training_stats(stats: list[HookHandler], 
                        labels: list[str] = None, 
                        colors: list[str] = None, 
                        figsize=(6, 2)):
    
    rows = len(stats)
    cols = len(stats[0].records.keys())
    colors = [f"C{i}" for i in range(rows)] if colors is None else colors
    H, W = figsize
    fig, ax = plt.subplots(rows, cols, figsize=(H * rows, W * cols))

    # calculate mean and std
    axf = lambda i, j: ax[j] if rows == 1 else ax[i, j]
    for i in range(rows):
        h = 0  # y lim
        for j, m in enumerate(stats[0].records.keys()):    
            if m not in stats[i].records:
                continue

            with torch.no_grad():
                r = torch.stack([t for t in stats[i].records[m]])
                r = r.view(r.shape[0], -1).cpu().numpy()
                u = np.percentile(r, q=50, axis=1)  # [a, b] contains 68% of data
                a = np.percentile(r, q=16, axis=1)  # 50 - 34 = 16
                b = np.percentile(r, q=84, axis=1)  # 50 + 34 = 84

            h = max(h, max(np.abs(a).mean(), np.abs(b).mean()).item())

            # plot
            axf(i, j).set_title(m.__class__.__name__ + "." + str(j), size=10)
            axf(i, j).plot(a, color=colors[i])
            axf(i, j).plot(b, color=colors[i])
            axf(i, j).fill_between(np.arange(len(u)), a, b, color=colors[i], alpha=0.3)
            axf(i, j).plot(u, color="black", label=f"$\mu$", alpha=0.8)
            
            axf(i, j).set_xlabel("steps")
            axf(i, j).grid(linestyle="dashed", alpha=0.6)
            if j == 0 and labels is not None:
                axf(i, j).set_ylabel(labels[i])

        # set y-lims
        for j, m in enumerate(stats[0].records.keys()):    
            if m in stats[i].records:
                axf(i, j).set_ylim(-2 * h, 2 * h)
                
    fig.tight_layout()
    plt.show()
print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_training_histograms([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(6, 1.2))
plot_training_stats([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], figsize=(6, 1.5))
loss (last 100 steps): 2.557870314121246
../../_images/1bec1d406bbcfe9e1830dd45415250f82e00670dd6e3a711e4e389005d52000d.svg../../_images/399b572f87d8d8ad259778e4dea28a5087eced582066ac0678b442c54a9e49cf.svg

Figure. First plot shows activation and gradient histograms over time. Range of values shown is \(\mu \pm 4 \sigma\) over all time steps, units, and layers. Considering each layer separately, the network is training well. But between layers we see that the gradient distribution changes. Since the distributions are not symmetric, the second plot shows the 16-50-84th percentiles of samples at each step.


# can also check relu stats
plot_training_histograms([relu_outs_stats, relu_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(5.0, 1.5))
plot_training_stats([relu_outs_stats, relu_grad_stats], labels=["Outputs", "Gradients"], figsize=(5, 1.5))
../../_images/906a6b8047fe5b5fa3efcd8dd9769396a08cc1b92c5c86657c88278be8e33882.svg../../_images/17cf86b68076f36317ba764a7c14d938f45c59fb65ff9da8d6245edb8dbab33a.svg

Dying units#

Good initialization prevents dying units:

def plot_dying_relus(trainer, stats):
    module_list = [m for m in trainer.model.modules() if isinstance(m, nn.ReLU)]
    plt.figure(figsize=(6, 3))
    for i, m in enumerate(module_list):
        if isinstance(m, nn.ReLU):
            moving_avg = []
            steps = len(stats.records[m])
            w = int(0.05 * steps)
            for s in range(steps):
                moving_avg.append(np.mean(stats.records[m][-w + s + 1: s + 1]))

            plt.plot(moving_avg, label=f"ReLU.{i}")
        
    plt.legend()
    plt.grid(linestyle="dashed", alpha=0.6)
    plt.title(f"sat_threshold={stats.sat_threshold:.1e}, frac_treshold={stats.frac_threshold:.2f}", size=10)
    plt.xlabel("step")
    plt.ylabel("frac dead relu (MA)")
    plt.ylim(-0.1, 1.1)
print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_dying_relus(trainer, dead_relu_stats)
loss (last 100 steps): 2.557870314121246
../../_images/d4589aa1062445906298b310feb40b90877036dae2bf2ced69b39640fea87565.svg

Increasing LR from 0.1 to 2.0. As discussed above, this can push weights to large values where the activations saturate, knocking the model into a flat region of the loss surface where the gradient is nearly zero regardless of which training examples are sampled. Hence, the model barely trains with further SGD steps.

from copy import copy

def _update(params, **kwargs) -> dict:
    """Helper to make stateless updates to params."""
    params = copy(params)
    for key, val in kwargs.items():
        params[key] = val
    return params


init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
hooks = init_hooks()
model = init_model(get_model(**model_params), init_fn)
trainer = run(model, train_dataset, **_update(train_params, lr=2.0), hooks=hooks)

print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_dying_relus(trainer, dead_relu_stats)
loss (last 100 steps): 2.813358087539673
../../_images/d14e62a7dfe95b39393030e962995420f5a4bff56c7293bfce8d795c142f3245.svg
plot_training_loss(trainer)
../../_images/2461d11155ee93bf124dca750948a2e6d55d1f198b1dcb9192db9acd3dc31d19.svg
plot_training_histograms([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(6, 1.2))
../../_images/d08367b4446adfddbf8a469cb9ec291b00c0c2d18133ba594e3cda6809454849.svg
plot_training_stats([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], figsize=(6, 1.2))
../../_images/4f5543c227fbbb504a4f217823c056b954b0981a29f183c29e0bf13de7be0cf0.svg

Layer normalization#

Recall that for weight initialization we have to specify the gain for each activation. It would be nice to learn the gain instead of setting it heuristically, or deriving it from equations. Moreover, from the above pathological examples, we find cases where initialization cannot help past the early steps (e.g. when we have large learning rates). To get stable activations during training, we want to normalize layer preactivations dynamically.

This can be done using normalization layers. In particular, we look at layer normalization [BKH16]. The following equations define the action of LayerNorm on an input \(\boldsymbol{\mathsf{z}}\) for one instance fed into the network:

\[\begin{split} \begin{aligned} {\mu} &= \frac{1}{n}\sum_{i} {\mathsf{z}}_{i} \\ {\sigma}^2 &= \frac{1}{n}\sum_{i} ({\mathsf{z}}_{i} - {\mu})^2 \\ \hat{{\mathsf{z}}}_{j} &= \frac{{\mathsf{z}}_{j} - {\mu}}{\sqrt{{\sigma}^2 + \epsilon}} \\ {{\mathsf{y}}}_{j} &= {\gamma}_j\, \hat{{\mathsf{z}}}_{j} + {\beta}_j \end{aligned} \end{split}\]

where \(n = |\boldsymbol{\mathsf{z}}|.\) This makes the outputs of each neuron lie in the same range. Consequently, this also helps control the magnitude of weights and weight gradients. Here we introduce trainable affine parameters \(\boldsymbol{\gamma}\) and \(\boldsymbol{\beta}\) that are applied per unit. This allows the output distributions to scale and shift, otherwise network expressiveness is degraded with preactivations mostly in \([-1, 1]\). It can be easily shown that LN is invariant to rescaling and recentering of the weight and data matrices.


n = 5
layernorm = nn.LayerNorm(n)

x = torch.randn(3, n)
u = x.mean(dim=1, keepdim=True)
v = x.var(dim=1, unbiased=False, keepdim=True)
eps = 0.00001
gamma = torch.ones(n)   # params init
beta = torch.zeros(n)

gamma * (x - u) / torch.sqrt(v + eps) + beta
tensor([[ 0.4980,  1.2870, -0.2261,  0.1768, -1.7358],
        [ 1.5455, -1.1316, -0.2491, -0.8746,  0.7098],
        [-0.0341, -1.5814,  1.2074,  0.8957, -0.4876]])
layernorm(x)
tensor([[ 0.4980,  1.2870, -0.2261,  0.1768, -1.7358],
        [ 1.5455, -1.1316, -0.2491, -0.8746,  0.7098],
        [-0.0341, -1.5814,  1.2074,  0.8957, -0.4876]],
       grad_fn=<NativeLayerNormBackward0>)
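We can also quickly check the invariance claim above by reusing layernorm and x from the cell above (a sketch; the small mismatch due to \(\epsilon\) is absorbed by the tolerance):

# LN output is unchanged when the inputs are rescaled and shifted
torch.allclose(layernorm(x), layernorm(5.0 * x + 2.0), atol=1e-4)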

Remark. The affine parameters have the same shape as normalized_shape (i.e. the main argument of LayerNorm). This is the same as the last D dimensions of the input tensor that are normalized over, where D is the length of normalized_shape.
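For example, normalizing over the last two dimensions of a rank-4 input:

ln = nn.LayerNorm([3, 5])
t = torch.randn(2, 4, 3, 5)
ln(t).shape, ln.weight.shape, ln.bias.shape   # output keeps its shape; weight and bias have shape (3, 5)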

Immediate effects#

We update our architecture with LN layers added before the activation function. Hence, preactivations are normalized around the active region of the activation, followed by a unit-wise shift and scale. Note that we also add normalization on the logits:

def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int,
              layernorm: bool = False):

    layers = [
        nn.Embedding(vocab_size, emb_size), 
        nn.Flatten(),
        nn.Linear(block_size * emb_size, width)
    ]
    layers.append(nn.LayerNorm(width)) if layernorm else ""
    layers.append(nn.ReLU())
    
    for _ in range(num_linear):
        layers.append(nn.Linear(width, width))
        layers.append(nn.LayerNorm(width)) if layernorm else ""
        layers.append(nn.ReLU())

    layers.append(nn.Linear(width, vocab_size))
    layers.append(nn.LayerNorm(vocab_size)) if layernorm else ""
    model = nn.Sequential(*layers)
    return model

Initialization. Having layer norm makes training less sensitive to weight initialization:

losses = {0.30: {}, 1.41: {}, 3.0: {}}  # sqrt(2) = 1.414...

for gain in [0.30, 1.41, 3.0]:
    init_fn = lambda w: nn.init.xavier_uniform_(w, gain=gain, generator=g())
    model_ln = init_model(get_model(**model_params, layernorm=True), init_fn)
    model_nf = init_model(get_model(**model_params), init_fn)
    
    trainer_ln = run(model_ln, train_dataset, hooks=[], **train_params)   # no hooks
    trainer_nf = run(model_nf, train_dataset, hooks=[], **train_params)   # NF = norm free
    losses[gain]["ln"] = trainer_ln.logs["train"]["loss"]
    losses[gain]["nf"] = trainer_nf.logs["train"]["loss"]

Observe that networks with LN layers start with better loss values:

plt.figure(figsize=(6, 4))
window = 0.03

for i, gain in enumerate(losses.keys()):
    train_loss = losses[gain]["nf"]
    plt.plot(moving_avg(train_loss, window), color=f"C{i}", label=f"gain={gain:.2f}", alpha=0.6)

for i, gain in enumerate(losses.keys()):
    train_loss = losses[gain]["ln"]
    plt.plot(moving_avg(train_loss, window), color=f"C{i}", label=f"gain={gain:.2f} (LN)", linewidth=2.0)

plt.ylabel("loss")
plt.xlabel("steps")
plt.grid(alpha=0.5, linestyle="dotted")
plt.ylim(2.2, 4.5)
plt.legend();
../../_images/20520f3e16610992ed91cffd095c081b3fc782aefeca5744f92d0e75f91415ed.svg

LN can be thought of as automatically and dynamically finding a per-unit gain:

# Using suboptimal PyTorch default gain:
init_fn = lambda w: nn.init.xavier_uniform_(w, gain=1/np.sqrt(3), generator=g())
model = init_model(get_model(layernorm=True, **model_params), init_fn)
hooks = init_hooks()
trainer = run(model, train_dataset, **train_params, hooks=hooks)

plot_training_histograms([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(6, 1.2))
../../_images/c2b4af2212f051137b94351a97e96c156ee3ccd95196500426f28102ad12b8c8.svg

Note that the initially shrinking activations are immediately corrected (the correction is more gradual without LN):

print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_training_stats([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], figsize=(6, 1.5))
loss (last 100 steps): 2.5370990657806396
../../_images/3aaf7a452508d14b662d712ae3dced46e9bbc8c73ca2de383c429ddb3500e8a3.svg

Large learning rate. LN allows the network to tolerate larger weight updates:

init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
hooks = init_hooks()
model = init_model(get_model(layernorm=True, **model_params), init_fn)
trainer = run(model, train_dataset, **_update(train_params, lr=2.0), hooks=hooks)

print("loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
plot_dying_relus(trainer, dead_relu_stats)
loss (last 100 steps): 2.576110579967499
../../_images/28d7b8b88cd814f9d05fbfc79163340975c971fcfe8c0785c8d35d0715cd588f.svg

Looks better behaved compared to the previous plot (unnormalized with lr=2.0):

plot_training_histograms([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], cmaps=["viridis", "inferno"], figsize=(6, 1.2))
../../_images/860ac5aad970a3c201e6d3404dc20347aed836073ab1fb17c244b861081cffe9.svg
plot_training_stats([linear_outs_stats, linear_grad_stats], labels=["Outputs", "Gradients"], figsize=(6, 1.2))
../../_images/bb17d6c25f389e82c58378d3afb11871a0e889d6b4c7e1e64e8cfd23487540eb.svg

Gradient normalization#

If we look at the above plot of the gradient distribution over time, LN seems to help not only with forward propagation (i.e. the model still makes meaningful predictions despite large weights), but also with backward propagation, since the gradient distributions look better compared to the unnormalized network. This makes sense: the preactivation units \(\boldsymbol{\mathsf{z}}\) now interact through the layer statistics (Fig. 48), so these units, and consequently their weights, are forced to co-adapt. The exact dependency can be seen in the BP equations for LN, which we derive below.

../../_images/05-layer-norm-compgraph.svg

Fig. 48 Computational graph of layer normalization. Dependencies are indicated by arrows.#

Let \(\boldsymbol{\mathsf{y}} = \text{LN}_{\gamma, \beta}(\boldsymbol{\mathsf{z}}).\) Parameters are straightforward:

\[ \frac{\partial\mathcal{L}}{\partial \gamma_j} = \frac{\partial \mathcal L}{\partial {\mathsf{y}}_{j}} \hat{{\mathsf{z}}}_j \quad \text{and} \quad \frac{\partial\mathcal{L}}{\partial \beta_j} = \frac{\partial \mathcal L}{\partial {\mathsf{y}}_{j}}. \]

Next, we calculate the layer stats:

\[\begin{split} \begin{aligned} \frac{\partial \mathcal L}{\partial \sigma^2} &= \sum_i\frac{\partial \mathcal L}{\partial {\mathsf{y}}_i} \frac{\partial {\mathsf{y}}_i}{\partial \sigma^2} \\ &= \sum_i \frac{\partial \mathcal L}{\partial {\mathsf{y}}_i} \gamma_i\left({\mathsf{z}}_i-\mu\right) \cdot-\frac{1}{2}\left(\sigma^2+\epsilon\right)^{-\frac{3}{2}} \\ \\ \frac{\partial \mathcal{L}}{\partial \mu}&=\frac{\partial \mathcal{L}}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial \mu}+\sum_i \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_i} \frac{\partial {\mathsf{y}}_i}{\partial \mu} \\ \\ &=\frac{\partial \mathcal{L}}{\partial \sigma^2} \frac{2}{n} \underbrace{\sum_i \left({\mathsf{z}}_i-\mu\right)}_{0} \cdot-1 + \sum_i \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_i} \frac{\partial {\mathsf{y}}_i}{\partial \mu} \\ &= \sum_i \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_i} \frac{\partial {\mathsf{y}}_i}{\partial \mu} \\ &= \sum_i \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_i} \frac{\partial}{\partial \mu}\left[\frac{\gamma_i({\mathsf{z}}_i-\mu)}{\sqrt{\sigma^2+\epsilon}}+\beta_i\right] = \sum_i \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_i} \frac{-\gamma_i}{\sqrt{\sigma^2+\epsilon}} \end{aligned} \end{split}\]

Pushing these gradients to the input node:

\[\begin{split} \begin{aligned} \frac{\partial \mathcal{L}}{\partial {\mathsf{z}}_j} &=\frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j} \frac{\partial {\mathsf{y}}_j}{\partial {\mathsf{z}}_j}+\frac{\partial \mathcal{L}}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial {\mathsf{z}}_j}+\frac{\partial \mathcal{L}}{\partial \mu} \frac{\partial \mu}{\partial {\mathsf{z}}_j} \\ &= \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j} \frac{\gamma_j}{\sqrt{\sigma^2+\varepsilon}}+\frac{\partial \mathcal{L}}{\partial \sigma^2} \frac{2}{n}\left({\mathsf{z}}_j-\mu\right)+\frac{\partial \mathcal{L}}{\partial \mu} \frac{1}{n} \\ &= \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j} \frac{\gamma_j}{\sqrt{\sigma^2+\varepsilon}} + \frac{1}{n}\frac{{\mathsf{z}}_j-\mu}{\sqrt{\sigma^2+\varepsilon}} \sum_{i} \frac{\partial \mathcal L}{\partial {\mathsf{y}}_{i}} \frac{-\gamma_i}{\sqrt{\sigma^2+\varepsilon}}\frac{{\mathsf{z}}_{i}-\mu}{\sqrt{\sigma^2+\varepsilon} } +\frac{1}{n} \sum_{i} \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_{i}} \frac{-\gamma_i}{\sqrt{\sigma^2+\varepsilon}} \\ &= \frac{1}{\sqrt{\sigma^2+\varepsilon}} \left( \gamma_j\frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_j} - \frac{1}{n}\hat{{\mathsf{z}}}_j \sum_{i} \gamma_i \frac{\partial \mathcal L}{\partial {\mathsf{y}}_{i}} \hat{{\mathsf{z}}}_{i} - \frac{1}{n} \sum_{i} \gamma_i \frac{\partial \mathcal{L}}{\partial {\mathsf{y}}_{i}} \right). \end{aligned} \end{split}\]

The last term centers the scaled output gradients \(\gamma_i \frac{\partial \mathcal L}{\partial {\mathsf{y}}_{i}}\), while the middle term removes their component along the standardized preactivation \(\hat{\boldsymbol{\mathsf{z}}}\) in the layer dimension. Observe that \(\sum_{j} \frac{\partial \mathcal{L}}{\partial {\mathsf{z}}_j} = 0,\) so that the mean of the downstream gradients is zero. Moreover, it can be shown (see [XSZ+19]) that:

\[\text{var}\left(\frac{\partial \mathcal{L}}{\partial {\mathsf{z}_j}}\right) \leq \frac{1}{\sigma^2}\text{var}\left(\gamma_j \frac{\partial \mathcal{L}}{\partial {\mathsf{y}_j}}\right).\]

A larger variance of preactivations results in a reduction in the variance of gradients, and vice-versa. This centering and rescaling of gradients can be thought of as gradient normalization. This extends to weight gradients since \(\frac{\partial\mathcal{L}}{\partial \boldsymbol{\mathsf{w}}} = \boldsymbol{\mathsf{x}}^\top \frac{\partial\mathcal{L}}{\partial \boldsymbol{\mathsf{z}}}\) where both factors are centered with controlled variance.


Gradient checking. Let us verify the formulas empirically using autograd. Here we assume x is the output of a previous activation layer. Then, z is the current preactivation which we pass through LN to get y. The outputs y are passed as logits to compute the cross-entropy loss with respect to some random labels, which we then backpropagate.

n = 128
m = 256

# Forward pass graph
B = 10000
x = torch.tanh(torch.randn((B, m), requires_grad=True))
w = 3 * torch.randn((m, n), requires_grad=True)
z = x @ w

# Nudge params to make sure 1 and 0 do not hide bugs
ε = 1e-5
γ = torch.randn(n).requires_grad_(True)
β = torch.randn(n).requires_grad_(True)
v = z.var(1, keepdim=True, unbiased=False)  # (!) Hours were spent finding an error in the calculation. Forgot default: unbiased=True
μ = z.mean(1, keepdim=True) 
ẑ = (z - μ) / torch.sqrt(v + ε)
y = γ * ẑ + β

for u in [z, y, x, w, γ, β, v, μ]:
    u.retain_grad()

t = torch.randint(0, n, size=(B,))
loss = -torch.log(F.softmax(y, dim=1)[range(B), t]).sum() / B
loss.backward()

Computing gradients by hand:

dy = y.grad
dγ = (dy * ẑ).sum(0)
dβ = dy.sum(0)
dv = -(γ * dy * (z - μ)).sum(1, keepdim=True) * 0.5 * (v + ε) ** -1.5
dμ = -(γ * dy).sum(1, keepdim=True) / (v + ε) ** 0.5
dz = γ * dy - (1 / n) * ẑ * (γ * dy * ẑ).sum(1, keepdim=True) - (1 / n) * (γ * dy).sum(1, keepdim=True)
dz *= (1 / torch.sqrt(v + ε))
dx = dz @ w.T
dw = x.T @ dz

Gradient check:

def compare(name, dt, t):
    exact  = torch.all(dt == t.grad).item()
    maxdiff = (dt - t.grad).abs().max().item()
    approx = torch.allclose(dt, t.grad, rtol=1e-7)
    print(f'{name} | exact: {str(exact):5s} | approx: {str(approx):5s} | maxdiff: {maxdiff:.2e}')

compare('y', dy, y)
compare('γ', dγ, γ)
compare('β', dβ, β)
compare('v', dv, v)
compare('μ', dμ, μ)
compare('z', dz, z)
compare('w', dw, w)
compare('x', dx, x)
y | exact: True  | approx: True  | maxdiff: 0.00e+00
γ | exact: True  | approx: True  | maxdiff: 0.00e+00
β | exact: True  | approx: True  | maxdiff: 0.00e+00
v | exact: False | approx: True  | maxdiff: 8.53e-14
μ | exact: False | approx: True  | maxdiff: 3.64e-12
z | exact: False | approx: True  | maxdiff: 1.82e-12
w | exact: False | approx: True  | maxdiff: 2.04e-10
x | exact: False | approx: True  | maxdiff: 2.18e-11

Demonstrating the variance relation. Note that the factor \(\gamma_j\) is obtained by dimensional analysis, since the proof in [XSZ+19] sets \(\gamma_j = 1\) for convenience. Moreover, \(\gamma_j\) and \(\frac{\partial \mathcal L}{\partial {\mathsf{y}}_j}\) always appear together in the equations above. The relation indeed seems to hold:

(dz.var(1) <= 1 / (v.view(-1) + ε) * (γ * dy).var(1)).float().mean().item()
0.9995999932289124

Observe that the bound is super tight:

gap = 1 / (v.view(-1) + ε) * (γ * dy).var(1) - dz.var(1)
gap.min().mean().item(), gap.max().mean().item()
(-2.710505431213761e-20, 5.109291895816215e-14)
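As a quick check of the earlier observation that the downstream gradients are centered, each row of dz should sum to approximately zero:

dz.sum(1).abs().max().item()   # ≈ 0 up to float error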

Remark. The above code uses a lot of broadcasting. The rule is that the shape of the lower-rank tensor is padded with 1s on the left whenever the ranks of the two tensors are unequal. If the shapes of two equal-rank tensors do not match, any dimension of size 1 is stretched to match the other shape. If any two non-unit dimensions are unequal, an error is raised.

For example, z - z.mean(1) fails. The shapes are [B, n] and [B,], respectively. The latter becomes [1, B] after the first rule, which cannot be broadcast against [B, n] since n ≠ B. Another case gives a bug (i.e. a silent mistake): (1 / v) * (γ * dy).var(1) has shapes [B, 1] and [B,], respectively, so the latter becomes [1, B]. Hence, the final shape of the two tensors is [B, B]! On the other hand, γ * dy works since the shapes are [n,] and [B, n]. The former becomes [1, n], which broadcasts to [B, n].
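A toy shape check of the buggy case (illustrative sizes only):

a = torch.randn(4, 1)              # plays the role of 1 / v, shape [B, 1]
b = torch.randn(4)                 # plays the role of (γ * dy).var(1), shape [B]
print((a * b).shape)               # torch.Size([4, 4]) -- silent bug from broadcasting
print((a * b.unsqueeze(1)).shape)  # torch.Size([4, 1]) -- the intended elementwise product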

Weight update ratio#

Another thing we can look at is the magnitude of weight updates relative to the weights. That is, the ratio \(\zeta_k = \frac{\|\lambda {\nabla}{\boldsymbol{\mathsf{w}}_k}\|}{\|\boldsymbol{\mathsf{w}}_k\|}\) where \({\nabla}{\boldsymbol{\mathsf{w}}_k} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{w}}_k}.\) A low overall \(\zeta_k\) means that convergence is slow, whereas a high \(\zeta_k\) allows the network to quickly unlearn good weight configurations.

Having layers with varying \(\zeta_k\) is not good for training as weights learned by slower layers can become outdated as faster layers learn new weights, unless the layers are doing different things (e.g. embeddings and logits). This can be seen in the plot below.

class WeightUpdateRatio(HookHandler):
    def __init__(self, lr: float):
        super().__init__()
        self.lr = lr

    def hook_fn(self, module, grad_in, grad_out):
        self.records[module] = self.records.get(module, [])
        for g in grad_in:
            if g is not None and g.shape == module.weight.data.shape:
                w = module.weight.data
                g_norm = torch.linalg.norm(g)
                w_norm = torch.linalg.norm(w)
                self.records[module].append(self.lr * (g_norm / w_norm).item())

Making the network deeper:

num_linear = 4
lr = 0.6
steps = 2000
init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
weight_update_ratios = {}

for ln in ["un", "ln"]:
    use_ln = ln == "ln"
    model = init_model(get_model(layernorm=use_ln, **_update(model_params, num_linear=num_linear)), init_fn)
    weight_update_ratio = WeightUpdateRatio(lr)
    hooks = [{"fwd_hooks": [], "bwd_hooks": [weight_update_ratio], "layer_types": (nn.Linear, nn.Embedding)}]
    trainer = run(model, train_dataset, **_update(train_params, lr=lr, steps=steps), hooks=hooks)
    print(f"[{ln}] loss (last 100 steps):", sum(trainer.logs["train"]["loss"][-100:]) / 100)
    weight_update_ratios[ln] = weight_update_ratio
[un] loss (last 100 steps): 2.540506160259247
[ln] loss (last 100 steps): 2.51398802280426
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

records = weight_update_ratios["un"].records
for j, m in enumerate(records.keys()):
    ax[0].plot(moving_avg(np.log10(np.array(records[m])), window=0.05), label=m.__class__.__name__ + "." + str(len(records) - 1 - j))
    ax[0].grid(linestyle="dotted", alpha=0.6)
    ax[0].set_title("Unnormalized", fontsize=12)
    ax[0].set_ylabel(r"$\log(\lambda \| \nabla \mathsf{\mathbf{w}} \|\; / \;\| \mathsf{\mathbf{w}} \|)$")
    ax[0].set_xlabel("steps")
    ax[0].set_ylim(-4, -1.0);

records = weight_update_ratios["ln"].records
for j, m in enumerate(records.keys()):
    ax[1].plot(moving_avg(np.log10(np.array(records[m])), window=0.05), label=m.__class__.__name__ + "." + str(len(records) - 1 - j))
    ax[1].grid(linestyle="dotted", alpha=0.6)
    ax[1].set_title("Normalized (linear + LN)", fontsize=12)
    ax[1].set_xlabel("steps")
    ax[1].set_ylabel(r"$\log(\lambda \| \nabla \mathsf{\mathbf{w}} \|\; / \;\| \mathsf{\mathbf{w}} \|)$")
    ax[1].legend()
    ax[1].set_ylim(-4, -1.0);
../../_images/c203196b77b1be9a86410ab70155f75dcfcb2a0f2760929214c1c3b3c6c557cb.svg

Figure. Model training with a deeper network (num_linear=4), Kaiming init, and a learning rate of 0.6. Since we took the logarithm, y=-2 corresponds to an update ratio of 0.01 for one mini-batch. This roughly means that it would take 100 steps to completely erase the current weights. The update ratios for the logits and embedding layers are nicely disambiguated: the logits weights train faster while the embeddings train slower. Moreover, notice that the update ratios for the hidden layers are well ordered.

Appendix: Activations#

gain = {
    "tanh": 5/3,
    "relu": math.sqrt(2),
    "selu": 1,                  # https://github.com/pytorch/pytorch/pull/53694#issuecomment-795683782
    "gelu": math.sqrt(2),       # GELU plot looks like ReLU
}

# Feel free to add to this list
activation_fns = {
    "tanh": nn.Tanh,
    "selu": nn.SELU,
    "relu": nn.ReLU,
    "gelu": nn.GELU
}

def d(act, x):
    """Compute derivative of activation with respect to x."""
    f = act()
    y = f(x)
    y.backward(torch.ones_like(x))
    return x.grad

# Region x s.t. |f'(x)| <= 0.1:
saturation_fns = {
    "tanh": lambda x: x.abs() > 1.8183,
    "selu": lambda x: x < -2.8668,
    "relu": lambda x: x < 0.0,
    "gelu": lambda x: (x < -1.8611) | ((-1.0775 < x) & (x < -0.5534)),
}
def plot_activation(act: str, ax, x):
    x.grad = None
    f = activation_fns[act]
    y = f()(x)
    df = d(f, x)
    x, y = x.detach().numpy(), y.detach().numpy()
    
    # Plotting
    ax.plot(x, y,  linewidth=2, color="red",   label="f(x)")
    ax.plot(x, df, linewidth=2, color="black", label="f'(x)")
    ax.set_title(act)
    ax.grid()
    ax.set_ylim(-2, 4)
    if act == "tanh":
        ax.axvspan(-5, -1.8183, -10, 10, color="lightgray")
        ax.axvspan(1.8183, 5,   -10, 10, color="lightgray")
    elif act == "relu":
        ax.axvspan(-5, 0, -10, 10, color="lightgray")
    elif act == "leaky_relu":
        ax.axvspan(-5, 0, -10, 10, color="lightgray")
    elif act == "elu":
        ax.axvspan(-5, -2.3025, -10, 10, color="lightgray")
    elif act == "selu":
        ax.axvspan(-5,      -2.8668, -10, 10, color="lightgray")
    elif act == "gelu":
        ax.axvspan(-5,      -1.8611, -10, 10, color="lightgray")
        ax.axvspan(-1.0775, -0.5534, -10, 10, color="lightgray", label="|f'(x)| ≤ 0.1")
        ax.legend(loc="lower right")


# Plotting
rows = math.ceil(len(activation_fns) / 2.0)
fig, ax = plt.subplots(rows, 2, figsize=(10, rows*3.5))

x = torch.linspace(-5, 5, 1000, requires_grad=True)
for i, act_name in enumerate(activation_fns.keys()):
    plot_activation(act_name, ax[divmod(i, 2)], x)  # divmod(m, n) = m // n, m % n

fig.tight_layout()
../../_images/15fd50611e925c2711ddff8b9b556548ffd82c38ec24ba7cfdd2a50dcbafc02f.svg

Figure. Activation functions and their derivatives. Saturation regions are highlighted gray. Note that GELU has 4 regions while the others only have 3 (Tanh) or 2 (ReLU and SELU). SELU has a jump in its derivative at zero, which can cause issues with a large LR.
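We can see this jump by evaluating the derivative just to the left and right of zero (a quick sketch):

x0 = torch.tensor([-1e-4, 1e-4], requires_grad=True)
nn.SELU()(x0).backward(torch.ones_like(x0))
x0.grad   # ≈ [1.7581, 1.0507]: SELU'(0-) = scale * alpha, SELU'(0+) = scale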

Having outputs that range to \(+\infty\) for ReLU (adopted by SELU and GELU) increases network expressivity. On the other hand, Tanh is bounded in \([-1, 1]\) with gradients that saturate in the asymptotes. Moreover, a derivative of 1 for a large region helps to accelerate training and avoid vanishing gradients. Recall that if \(\boldsymbol{\mathsf{y}} = \varphi(\boldsymbol{\mathsf{z}})\), where \(\boldsymbol{\mathsf{z}} = \boldsymbol{\mathsf{x}}^\top\boldsymbol{\mathsf{w}}\) and \(\varphi\) is an activation, then

\[ \frac{\partial \mathcal L}{\partial \boldsymbol{\mathsf{x}}} = \left( \frac{\partial \mathcal L}{\partial \boldsymbol{\mathsf{y}}} \odot \varphi^{\prime}(\boldsymbol{\mathsf{z}}) \right) \boldsymbol{\mathsf{w}}^{\top}. \]

Here \(\varphi^{\prime}(\boldsymbol{\mathsf{z}})\) is a matrix with entries \(\varphi^{\prime}(\boldsymbol{\mathsf{z}}_{bj})\) and \(\odot\) denotes the Hadamard product.

Stacking layers effectively forms a chain of weight matrix multiplications, with the gradient entries scaled in between by derivatives of the activation function. For \(\varphi = \tanh\) the gradients are scaled down since \(\tanh^\prime(z) \leq 1\) for any \(z \in \mathbb R.\) Hence, creating a deep stack of fully-connected hidden layers can result in vanishing gradients at initialization. ReLU is interesting as it allows sparse gradient updates since \(\varphi^\prime(z) \in \{0, 1\}.\)
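We can verify the backprop equation above with autograd (a small sketch with \(\varphi = \tanh\)):

B, n_in, n_out = 8, 5, 4
x = torch.randn(B, n_in, requires_grad=True)
w = torch.randn(n_in, n_out)
y = torch.tanh(x @ w)
dy = torch.randn_like(y)                  # pretend upstream gradient dL/dy
y.backward(dy)

dx = (dy * (1 - y.detach() ** 2)) @ w.T   # (dL/dy ⊙ φ'(z)) wᵀ with φ'(z) = 1 - tanh(z)²
print(torch.allclose(dx, x.grad, atol=1e-6))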

def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int,
              activation: str = "relu",
              layernorm: bool = False):

    act_fn = activation_fns[activation]
    layers = [
        nn.Embedding(vocab_size, emb_size),
        nn.Flatten(),
        nn.Linear(block_size * emb_size, width)
    ]
    if layernorm:
        layers.append(nn.LayerNorm(width))
    layers.append(act_fn())

    for _ in range(num_linear):
        layers.append(nn.Linear(width, width))
        if layernorm:
            layers.append(nn.LayerNorm(width))
        layers.append(act_fn())

    layers.append(nn.Linear(width, vocab_size))
    if layernorm:
        layers.append(nn.LayerNorm(vocab_size))
    model = nn.Sequential(*layers)
    return model

Running an experiment:

model_params = {
    "vocab_size": 29,
    "block_size": 3,
    "emb_size": 10,
    "width": 30,
    "num_linear": 6
}

train_params = {
    "steps": 2000,
    "batch_size": 32,
    "lr": 0.2
}


train_hooks = {"ln": {}, "un": {}}
trainers    = {"ln": {}, "un": {}}

for act in activation_fns.keys():
    for use_ln in [True, False]:
        init_fn = lambda w: nn.init.xavier_uniform_(w, gain=gain[act], generator=g())
        activation_stats = OutputStats()
        weight_grad_stats = WeightGradientStats()
        weight_update_ratio = WeightUpdateRatio(train_params["lr"])
        hooks = [
            {
                "fwd_hooks": [], 
                "bwd_hooks": [weight_grad_stats, weight_update_ratio], 
                "layer_types": (nn.Linear, nn.Embedding)
            },
            {
                "fwd_hooks": [activation_stats], 
                "bwd_hooks": [], 
                "layer_types": tuple(activation_fns.values())
            }
        ]
        model = init_model(get_model(activation=act, layernorm=use_ln, **model_params), init_fn)
        trainer = run(model, train_dataset, **train_params, hooks=hooks)

        k = "ln" if use_ln else "un"
        trainers[k][act] = trainer
        train_hooks[k][act] = train_hooks[k].get(act, {})
        train_hooks[k][act]["outs"] = activation_stats
        train_hooks[k][act]["grad"] = weight_grad_stats
        train_hooks[k][act]["rate"] = weight_update_ratio

SELU does not train well with a large LR (it can diverge even at lr=0.3) compared to the other activations, probably due to the spike in its derivative at zero. The initial loss for Tanh improves with LN, consistent with the discussion above.

Hide code cell source
for l, ln in enumerate(["un", "ln"]):
    message = "train loss (normalized)" if ln == "ln" else "train loss (unnormalized)"
    print(message)
    for k, act in enumerate(activation_fns):
        trainer = trainers[ln][act]
        train_loss = trainer.logs["train"]["loss"]
        start = np.array(train_loss[:100]).mean()
        end = np.array(train_loss[-100:]).mean()
        diff = end - start
        print(f"    {act}    start: {start:.3f}    end: {end:.3f}    diff: {diff:.3f}")
    print()
train loss (unnormalized)
    tanh    start: 2.996    end: 2.503    diff: -0.493
    selu    start: 2.975    end: 2.503    diff: -0.471
    relu    start: 3.005    end: 2.521    diff: -0.485
    gelu    start: 3.015    end: 2.487    diff: -0.528

train loss (normalized)
    tanh    start: 2.968    end: 2.479    diff: -0.490
    selu    start: 2.919    end: 2.484    diff: -0.435
    relu    start: 3.002    end: 2.559    diff: -0.442
    gelu    start: 2.985    end: 2.505    diff: -0.480

Histograms. Activation and weight-gradient distributions of the hidden layers at the final training step:

Hide code cell source
step = -1
for key in ["outs", "grad"]:
    fig, ax = plt.subplots(2, len(activation_fns), figsize=(12, 6))
    for i, act in enumerate(activation_fns.keys()):
        h = 0
        for k, ln in enumerate(["un", "ln"]):
            records = train_hooks[ln][act][key].records
            with torch.no_grad():
                for j, m in enumerate(records.keys()):
                    layer_types = tuple(activation_fns.values()) if key == "outs" else nn.Linear
                    if isinstance(m, layer_types) and 0 < j < len(records) - 1:
                        data = records[m][step].reshape(-1).cpu().numpy()
                        h = max(h, np.abs(np.percentile(data, 99)))
                        sns.distplot(data, bins=50, ax=ax[k, i], label=f"Linear.{j+1}", hist=False)

            ax[k, i].set_xlim(-h, h)
            ax[k, i].set_title(f"{act} ({'unnormalized' if ln == 'un' else 'layernorm'})")
            ax[k, i].ticklabel_format(axis="x", style="sci", scilimits=(-2, 2))

    ax[k, i].legend(fontsize=8)
    fig.suptitle(f"{'Activation' if key == 'outs' else 'Gradient'} distribution at final step", fontsize=15)
    fig.tight_layout()
../../_images/f05001b6c1370addac6cef0de998564d8b6f2d8997c698c3f59236f50cb97d0e.svg../../_images/3386a79d35a4831b2c34c194eca466ca53b7a41b975576587faab842f14ddfd2.svg

Observe that Tanh outputs have a small range compared to the other activations. ReLU and GELU have tails on the positive side, while SELU is more symmetric with a heavy tail on the negative axis. LN makes the distributions more similar across layers. The plots of percentiles per layer below allow us to quantify this better:

Hide code cell source
step = -1
for key in ["outs", "grad"]:
    fig, ax = plt.subplots(1, len(activation_fns), figsize=(10, 3))
    for k, ln in enumerate(["un", "ln"]):
        alpha = 1.0 if ln == "ln" else 0.8
        linestyle = "dotted" if ln == "un" else "solid"
        linewidth = 2.0 if ln == "ln" else 1.5
        label = "LN" if ln == "ln" else "un"
        for i, act in enumerate(activation_fns.keys()):
            s = []
            u = []
            t = []
            records = train_hooks[ln][act][key].records
            with torch.no_grad():
                for j, m in enumerate(records.keys()):
                    layer_types = tuple(activation_fns.values()) if key == "outs" else nn.Linear
                    if isinstance(m, layer_types) and 0 < j < len(records) - 1:
                        data = records[m][step].reshape(-1).cpu().numpy()
                        s.append(np.percentile(data, 16))
                        u.append(np.percentile(data, 50))
                        t.append(np.percentile(data, 84))

            ax[i].plot(np.arange(len(s)) + 2, s, color=f"C{i}", alpha=alpha, linewidth=linewidth, linestyle=linestyle, label=label)
            ax[i].plot(np.arange(len(s)) + 2, t, color=f"C{i}", alpha=alpha, linewidth=linewidth, linestyle=linestyle)
            ax[i].plot(np.arange(len(s)) + 2, u, color="black", alpha=alpha, linewidth=linewidth, linestyle=linestyle)
            ax[i].fill_between(np.arange(len(s)) + 2, s, t, color=f"C{i}", alpha=0.2)

            ax[i].set_title(act)
            ax[i].set_xlabel("layer")
            ax[i].set_ylabel(r"$\sigma$")
            ylim = (-3, 3) if key == "outs" else (-0.05, 0.05)
            ax[i].set_ylim(*ylim)
            ax[i].grid(alpha=0.4, linestyle="dashed", linewidth=0.6)
            ax[i].legend(fontsize=8)
            ax[i].set_xticks(np.arange(len(s)) + 2)

    fig.suptitle(f"{'Activation' if key == 'outs' else 'Weight gradient'} distribution at final step", fontsize=15)
    fig.tight_layout()
../../_images/0faacb0a1cb4de6e09edfce2b7448b510e4eaa61ca9c818e6c3981c73be2c411.svg../../_images/0be19f68507ce1b9cf467680e124ff550f70b342f97745d6f78e2330f3b32197.svg

ReLU and GELU have decreasing gradients across layers, while Tanh and SELU have increasing gradients. In general, the propagation of weight gradients looks more constant across layers with LN. Next, observe that SELU (and GELU to a lesser degree), which are not symmetric, have shifting median activations.

This biases the inputs \(\boldsymbol{\mathsf{x}}\) of a layer toward all having the same sign. Recall that \(\frac{\partial\mathcal{L}}{\partial {\mathsf{w}}_{ij}} = {\mathsf{x}}_i \frac{\partial\mathcal{L}}{\partial {\mathsf{z}}_j},\) so that \(\frac{\partial\mathcal{L}}{\partial {\mathsf{w}}_{ij}}\) has the same sign for all \(i,\) which can slow down learning. Also, when two independent random input vectors share a bias shift, the inner product of their outputs tends to become larger:

n = 1000
w = torch.randn(20, 20)
a = torch.randn(n, 20) 
b = torch.randn(n, 20)
print(((a @ w) * (b @ w)).sum(dim=1).abs().mean())
print((((a + 0.2) @ w) * ((b + 0.2) @ w)).sum(dim=1).abs().mean())
tensor(107.3766)
tensor(110.3988)

This effect compounds with depth. Hence, deep networks affected by bias shift tend to predict the same label for all training examples at initialization. Bias shift is counteracted by LN, which ensures that, for each input, the mean activation across the units of a layer is zero.
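A minimal sketch of this (illustrative only, not part of the experiments above): LN subtracts each example's mean across the layer's units, so a constant shift of the activations does not propagate to the next layer.

import torch
import torch.nn as nn

ln = nn.LayerNorm(20)
h = torch.randn(8, 20) + 3.0     # activations with a large positive shift
out = ln(h)
print(h.mean(dim=1))             # ≈ 3 for every example
print(out.mean(dim=1))           # ≈ 0 for every example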


Weight update ratio. LN improves the separation and tightness of the update-ratio curves. SELU looks ideal (e.g. fewer fluctuations, tightness across hidden layers of the same shape); these properties seem to be correlated with better loss. The significant improvement of the Tanh and SELU curves may explain why their performance improved with LN.
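The plotted quantity can be computed directly from the parameters after a backward pass. A minimal standalone sketch (the layer, input, and loss here are made up purely for illustration):

import torch
import torch.nn as nn

lr = 0.1
layer = nn.Linear(10, 10)
x = torch.randn(32, 10)
layer(x).pow(2).mean().backward()

# log10 of lr * ‖∇w‖ / ‖w‖, the same quantity shown on the y-axis below
ratio = lr * layer.weight.grad.norm() / layer.weight.norm()
print(ratio.log10().item())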

Hide code cell source
from typing import Literal

def plot_weight_update_ratio(key: Literal["un", "ln"], ylim=(-3.0, -0.5), legend=False):
    fig, ax = plt.subplots(1, len(activation_fns), figsize=(12, 5))
    for i, act in enumerate(activation_fns):
        records = train_hooks[key][act]["rate"].records
        title = f"{act} (normalized)" if key == "ln" else f"{act} (unnormalized)"
        for j, m in enumerate(records.keys()):
            ax[i].plot(moving_avg(np.log10(np.array(records[m])), window=0.05), label=m.__class__.__name__ + "." + str(len(records) - 1 - j))
            ax[i].grid(linestyle="dotted", alpha=0.6)
            ax[i].set_title(title, fontsize=12)
            ax[i].set_ylabel(r"$\log(\lambda \| \nabla \mathbf{w} \|\; / \;\| \mathsf{\mathbf{w}} \|)$")
            ax[i].set_xlabel("steps")
            ax[i].set_ylim(*ylim)

    if legend:
        ax[-1].legend(fontsize=8, loc="lower right")
        
    fig.tight_layout()


plot_weight_update_ratio("un", ylim=(-4.0, -1))
plot_weight_update_ratio("ln", ylim=(-4.0, -1), legend=True)
../../_images/21703f85e5aa6fa81caf4014170359da8fe7b03068df5367d4ed9ede1cd18d21.svg../../_images/1fcbddf63235e03d52527275977dcff0b8826ea73990b72c5484a9958d661c4d.svg

Appendix: Rank collapse#

Rank collapse happens when the outputs of a layer lie in a small subspace of its output space. Xavier initialization by itself cannot prevent rank collapse. In fact, the gap between the first and the remaining singular values of a product of Gaussian matrices sampled with Xavier initialization grows exponentially with depth (§3.2 of [DKB+20]). Activations help mitigate rank collapse by adding non-linearities in between the weight matrices.
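A minimal sketch of this claim (standalone, not the hook-based experiment below): multiply Xavier-initialized square matrices and track how much the first singular value dominates the rest as depth grows.

import torch
import torch.nn as nn

torch.manual_seed(0)
width, depth = 30, 30
P = torch.eye(width)
for d in range(1, depth + 1):
    W = nn.init.xavier_uniform_(torch.empty(width, width))
    P = W @ P
    if d % 10 == 0:
        s = torch.linalg.svdvals(P)
        print(f"depth={d:2d}    σ1 / Σσ = {(s[0] / s.sum()).item():.3f}")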

One way to measure rank numerically is by using singular values. Every matrix \(\boldsymbol{\mathsf{A}}\) has a singular value decomposition (SVD) \(\boldsymbol{\mathsf{A}} = \boldsymbol{\mathsf{U}} \boldsymbol{\Sigma} \boldsymbol{\mathsf{V}}^\top\) where \(\boldsymbol{\mathsf{U}}\) and \(\boldsymbol{\mathsf{V}}\) are orthogonal matrices whose columns form orthonormal bases of the output and input spaces, respectively, and \(\boldsymbol{\Sigma}\) is a diagonal matrix containing the singular values of \(\boldsymbol{\mathsf{A}}.\) It turns out that the number of nonzero singular values of a matrix is equal to its rank. In particular, the first singular value \(\sigma_1\) is equal to the operator norm:

\[\sigma_1 = \max_{\lVert \boldsymbol{\mathsf{x}} \rVert = 1}\; \lVert \boldsymbol{\mathsf{A}} \boldsymbol{\mathsf{x}} \rVert.\]

Geometrically, this corresponds to the length of the longest axis of the ellipsoid obtained by transforming the unit sphere in the input space by \(\boldsymbol{\mathsf{A}}.\) The other singular values correspond to the lengths of the other axes of the ellipsoid. For example, if only two singular values are nonzero, then the output ellipsoid, and hence the image of \(\boldsymbol{\mathsf{A}},\) lies in a two-dimensional subspace.
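A quick check of the claim that the rank equals the number of nonzero singular values (illustrative only, not part of the experiment):

import torch

A = torch.randn(5, 2) @ torch.randn(2, 5)   # rank-2 by construction
print(torch.linalg.svdvals(A))              # only two values are (numerically) nonzero
print(torch.linalg.matrix_rank(A))          # tensor(2)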

Hide code cell source
rng = np.random.RandomState(0)
A = rng.randn(2, 2)
u, s, vT = np.linalg.svd(A)

N = 100
t = np.linspace(-1, 1, N)
unit_circle = np.stack([np.cos(2*np.pi*t), np.sin(2*np.pi*t)], axis=0)
outputs = A @ unit_circle

# Plot
plt.scatter(unit_circle[0, :], unit_circle[1, :], s=0.8, label="x")
plt.scatter(outputs[0, :], outputs[1, :], s=0.8, label="Ax")
plt.axis('scaled')

# Checking: max norm == s[0]
max_output_norm  = np.linalg.norm(outputs, axis=0).max()
min_output_norm  = np.linalg.norm(outputs, axis=0).min()
max_output_norm_ = np.linalg.norm(outputs, axis=0).argmax()

print("A=")
print(A)
print()
print(f'max ‖Ax‖ = {max_output_norm:.2f}  (‖x‖ = 1)')
print(f'min ‖Ax‖ = {min_output_norm:.2f}')
print(f'σ₁ = {s[0]:.2f}')
print(f'σ₂ = {s[1]:.2f}')


# Plotting singular vectors as axis of ellipse
plt.arrow(0, 0, s[0]*u[0, 0], s[0]*u[1, 0], width=0.01, length_includes_head=True)
plt.arrow(0, 0, s[1]*u[0, 1], s[1]*u[1, 1], width=0.01, length_includes_head=True)
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.legend()
plt.savefig('../../img/nn/05-singular-ellipsoid.svg')
plt.close()
A=
[[1.76405235 0.40015721]
 [0.97873798 2.2408932 ]]

max ‖Ax‖ = 2.75  (‖x‖ = 1)
min ‖Ax‖ = 1.29
σ₁ = 2.75
σ₂ = 1.29
../../_images/singular-ellipsoid.svg

Fig. 49 The singular values of a matrix \(\boldsymbol{\mathsf{A}}\) are equal to the lengths of the axes of the output ellipsoid for an input unit sphere. The arrows in the ellipsoid are obtained from the left singular vectors: \(\sigma_1\, \boldsymbol{\mathsf{u}}_1\) and \(\sigma_2\, \boldsymbol{\mathsf{u}}_2.\)#

Batch normalization. Below we test rank collapse with LN and batch normalization (BN) [IS15]. BN is similar to LN except that normalization occurs along the batch dimension. As such, BN behaves differently during training and inference, since batch statistics have to be estimated in order to make predictions on single instances at test time. This complicates using BN in practice (also see [WJ21]). However, in addition to having the same benefits as LN, BN avoids rank collapse [DKB+20], as the following experiment will show.
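To make the difference concrete, a minimal sketch (the tensor shapes here are arbitrary): BN normalizes each feature over the batch dimension, while LN normalizes each example over the feature dimension.

import torch
import torch.nn as nn

x = torch.randn(32, 8) * 3 + 1
bn = nn.BatchNorm1d(8)
ln = nn.LayerNorm(8)

print(bn(x).mean(dim=0).abs().max())   # ≈ 0: per-feature mean over the batch
print(ln(x).mean(dim=1).abs().max())   # ≈ 0: per-example mean over the features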

def get_model(vocab_size: int,
              block_size: int,
              emb_size: int,
              width: int,
              num_linear: int,
              activation: str = "relu",
              norm: str = None):

    norm_layer = {
        "layernorm": nn.LayerNorm,
        "batchnorm": nn.BatchNorm1d
    }[norm] if norm else None

    act_fn = activation_fns[activation]
    layers = [
        nn.Embedding(vocab_size, emb_size),
        nn.Flatten(),
        nn.Linear(block_size * emb_size, width)
    ]
    if norm:
        layers.append(norm_layer(width))
    layers.append(act_fn())

    for _ in range(num_linear):
        layers.append(nn.Linear(width, width))
        if norm:
            layers.append(norm_layer(width))
        layers.append(act_fn())

    layers.append(nn.Linear(width, vocab_size))
    if norm:
        layers.append(norm_layer(vocab_size))
    model = nn.Sequential(*layers)
    return model

Hook for calculating singular values:

class SingularValuesRatio(HookHandler):
    def hook_fn(self, module, input, output):
        self.records[module] = self.records.get(module, [])
        s = torch.linalg.svd(output)[1] # singular values
        self.records[module].append((s[0] / s.sum()).item())

Here we are careful not to use a large LR. Dying units cause rank collapse, which may lead us to incorrectly conclude that LN mitigates rank collapse better than an unnormalized network does (since LN prevents dying units).
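For reference, a unit is "dead" when it outputs zero for every input in a batch; its outputs are then constant, which lowers the rank of the layer's outputs. A minimal sketch (the large negative bias below is artificial, just to simulate units knocked out by an overly large update):

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(30, 30), nn.ReLU())
with torch.no_grad():
    layer[0].bias[:5] = -100.0   # simulate 5 units pushed into the dead region

h = layer(torch.randn(256, 30))
dead = (h == 0).all(dim=0).float().mean()
print(f"fraction of dead units: {dead.item():.2f}")   # ≈ 5 / 30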

# Note: act = relu. But same results for tanh and gelu.
sv_hooks = {}
message = ["\nloss (last 100 steps)"]
model_params["num_linear"] = 6
train_params["lr"] = 0.3
train_params["steps"] = 3000

for norm in ["batchnorm", "layernorm", None]:
    
    init_fn = lambda w: nn.init.xavier_uniform_(w, gain=np.sqrt(2), generator=g())
    model = init_model(get_model(norm=norm, **model_params), init_fn)
    sv_ratio = SingularValuesRatio()
    hooks = [{"bwd_hooks": [], "fwd_hooks": [sv_ratio], "layer_types": (nn.Linear,)}]
    trainer = run(model, train_dataset, **train_params, hooks=hooks)

    sv_hooks[norm] = sv_ratio
    message.append(f"    {norm}: \t{sum(trainer.logs['train']['loss'][-100:]) / 100:.3f}")

print("\n".join(message))
loss (last 100 steps)
    batchnorm: 	2.466
    layernorm: 	2.492
    None: 	2.499
Hide code cell source
fig, ax = plt.subplots(1, len(sv_hooks), figsize=(10, 4))
for i, norm in enumerate(sv_hooks):
    records = sv_hooks[norm].records
    for j, m in enumerate(records.keys()):
        ax[i].plot(moving_avg(records[m], window=0.0005), label=m.__class__.__name__ + "." + str(j), color=f"C{j}", alpha=min(0.5 + (0.5 / len(records.keys())) * j, 1.0))
        ax[i].set_ylim(0, 1)
        ax[i].grid(linestyle="dashed", alpha=0.5)
        ax[i].set_xlabel("steps")
        ax[i].set_title(norm)

ax[0].set_ylabel(r"${\sigma}_1\; /\; \mathbf{\sigma}$.sum()")
ax[-1].set_title("Unnormalized")
ax[-1].legend(fontsize=7)
fig.tight_layout()
../../_images/8262d3d33c4fb121b234b116301adce4d49fa13899ad75527626691e213031ad.svg

Remark. The degree of rank collapse increases with depth. This makes sense for linear networks since \(\text{rank}\,\boldsymbol{\mathsf{A}}\boldsymbol{\mathsf{X}} \leq \text{rank}\,\boldsymbol{\mathsf{X}}\) (see eqn. (8) in [FZH+22] for general networks). LN does not fully prevent rank collapse. In the next notebook, we will look at residual connections, which help mitigate rank collapse without having to use BN.

Appendix: Bias initialization#

Recall that in our experiments biases are initialized to zero. This makes sense for hidden layers. However, the bias of the final layer deserves more care. For example, given an imbalanced dataset with a binary label ratio of 1:10, it makes sense to calibrate the bias of the logits so that the network predicts a probability of roughly \(0.1\) for the minority class at initialization. Setting the logit biases correctly can speed up convergence, since otherwise the first few training steps are spent just learning the bias.
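For a two-logit softmax head (as used below), setting the bias to \([0, \log r]\), where \(r\) is the class ratio, makes the initial predicted probabilities match the empirical class frequencies. A minimal sketch for the 1:10 example above (the numbers are illustrative):

import numpy as np

n0, n1 = 1, 10                          # 1:10 class imbalance
b = np.array([0.0, np.log(n1 / n0)])    # bias of the two logits
p = np.exp(b) / np.exp(b).sum()
print(p)                                # [0.0909..., 0.9090...] ≈ [0.1, 0.9]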

This technique is demonstrated below. The following dataset has a 1:6 class imbalance, i.e. the minority class makes up \(1/7\) of the data:

import torch 
torch.manual_seed(2)

N = 1500    # inner circle
M = 6 * N   # outer circle
noise = lambda n, e: torch.randn(n, 2) * e
s = 2 * torch.pi * torch.rand(N, 1)
t = 2 * torch.pi * torch.rand(M, 1)

x0 = torch.cat([0.1 * torch.cos(s), 0.1 * torch.sin(s)], dim=1) + noise(N, 0.05)
x1 = torch.cat([1.0 * torch.cos(t), 1.0 * torch.sin(t)], dim=1) + noise(M, 0.1)
y0 = (torch.ones(N,) * 0).long()
y1 = (torch.ones(M,) * 1).long()
Hide code cell source
plt.scatter(x0[:, 0], x0[:, 1], s=2.0, label=0, color="C0")
plt.scatter(x1[:, 0], x1[:, 1], s=2.0, label=1, color="C1")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.legend()
plt.axis('equal');
../../_images/3e8c48d32ba637658a65ed272ad048546862243046c08f4498eb56656d02ec94.svg
x = torch.cat([x0, x1])
y = torch.cat([y0, y1])
ds = torch.utils.data.TensorDataset(x, y)
train_loader = lambda: DataLoader(ds, batch_size=64, shuffle=True)
y.float().mean()    # 6 / 7 = 0.8571428571428571
tensor(0.8571)

Training our baseline model:

model = lambda: nn.Sequential(
    nn.Linear(2, 3), nn.ReLU(),
    nn.Linear(3, 3), nn.ReLU(),
    nn.Linear(3, 2)
)

RANDOM_SEED = 0
set_seed(RANDOM_SEED)
model_base = model()
optim = torch.optim.Adam(model_base.parameters(), lr=0.1)
losses_base = []
for epoch in range(2):
    for xb, yb in train_loader():
        logits = model_base(xb)
        loss = F.cross_entropy(logits, yb)
        loss.backward()
        optim.step()
        optim.zero_grad()
        losses_base.append(loss.item())

Next, we initialize the test model with careful bias initialization on the logits:

set_seed(RANDOM_SEED)
model_test = model()
logits = list(model_test.modules())[-1]
logits.bias.data = torch.tensor([0.0, np.log(6)]).float()   # w0 = 0, w1 s.t. exp(w1) / (exp(0) + exp(w1)) = 6 / 7
F.softmax(model_test(xb), dim=1).mean(dim=0)
tensor([0.1388, 0.8612], grad_fn=<MeanBackward1>)

Training the test model:

optim = torch.optim.Adam(model_test.parameters(), lr=0.1)
losses_test = []
for epoch in range(2):
    for xb, yb in train_loader():
        logits = model_test(xb)
        loss = F.cross_entropy(logits, yb)
        loss.backward()
        optim.step()
        optim.zero_grad()
        losses_test.append(loss.item())

Results. The test model converges faster, i.e. it has better starting and final loss. This is consistent over multiple random seeds (3 or 4 were tried).

Hide code cell source
plt.plot(losses_base, label="base")
plt.plot(losses_test, label="test")
plt.legend();
../../_images/b8d6256405046991e5149347cbb52094eb9d3578c6f79fcaa214ef5f142682e5.svg
print(f"base: [{losses_base[0]:.3e}, {losses_base[-1]:.3e}]")
print(f"test: [{losses_test[0]:.3e}, {losses_test[-1]:.3e}]")
base: [6.208e-01, 4.768e-07]
test: [3.772e-01, 0.000e+00]