Backprop Through Time (BPTT)

Recall that the state is updated at each time step from the current input, using the same shared parameters at every step. Computing gradients therefore requires tracking dependencies across time steps. This procedure is called Backpropagation Through Time (BPTT), i.e., backpropagation for sequence models. Hopefully, this discussion will bring some precision to the notion of vanishing and exploding gradients.

This procedure requires us to expand (or unroll) the computational graph of an RNN one time step at a time. The unrolled RNN is essentially a feedforward neural network with the special property that the same parameters are repeated throughout the unrolled network, appearing at each time step. We can then apply ordinary backpropagation to the unrolled network. In particular, we want to see causality in the equations, i.e., the state at time \(t\) only influences the current and future time steps.

For long sequences, e.g., text sequences containing over a thousand tokens, backpropagating through this many layers poses problems both from a computational standpoint (too much history to compress into a single state vector) and from an optimization standpoint (numerical instability). The input from the first step passes through \(T\) matrix products before arriving at the output. Similarly, computing the gradient with respect to the first time step requires a product of \(T\) matrices.


Fig. 59 RNN cell backpropagation. Note that the matrices \(\boldsymbol{\mathsf{W}}, \boldsymbol{\mathsf{U}},\) and \(\boldsymbol{\mathsf{V}}\) are shared across time steps.

Recall:

\[\begin{split} \begin{aligned} \boldsymbol{\mathsf{H}}_t &= f(\boldsymbol{\mathsf{X}}_t \boldsymbol{\mathsf{U}} + \boldsymbol{\mathsf{H}}_{t-1} \boldsymbol{\mathsf{W}} + \boldsymbol{\mathsf{b}}) \\ \boldsymbol{\mathsf{Y}}_t &= \boldsymbol{\mathsf{H}}_t \boldsymbol{\mathsf{V}} + \boldsymbol{\mathsf{c}} \\ \boldsymbol{\mathsf{H}}_{t+1} &= f(\boldsymbol{\mathsf{X}}_{t+1} \boldsymbol{\mathsf{U}} + \boldsymbol{\mathsf{H}}_{t} \boldsymbol{\mathsf{W}} + \boldsymbol{\mathsf{b}}). \end{aligned} \end{split}\]

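Assume the incoming gradients \(\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{Y}}_t}\) and \(\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{H}}_{t+1}}\) are given from the next layer. Here \(f\) is an activation function applied elementwise, and upper case indicates that a tensor's first dimension is the batch dimension when applicable. Shapes are annotated under each term, with \(B\) the batch size, \(d\) the input dimension, \(h\) the hidden dimension, and \(q\) the output dimension. To keep the notation light, we write \(\text{prod}\) for the product of two tensors contracted over the appropriate indices; the exact formula can be recovered with tensor index notation. We start with the gradients with respect to \(\boldsymbol{\mathsf{V}}\) and \(\boldsymbol{\mathsf{c}},\) which collect a contribution from every time step since these parameters are shared: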

\[\begin{split} \begin{aligned} \underbrace{\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{V}}}}_{(h, q)} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{Y}}_t}, \frac{\partial \boldsymbol{\mathsf{Y}}_t}{\partial \boldsymbol{\mathsf{V}}}\right) = \sum_{t=1}^T \underbrace{\boldsymbol{\mathsf{H}}_t^\top \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{Y}}_t}}_{(h, B) \,\times\, (B, q)} \\ \underbrace{\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{c}}}}_{(1, q)} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{Y}}_t}, \frac{\partial \boldsymbol{\mathsf{Y}}_t}{\partial \boldsymbol{\mathsf{c}}}\right) = \sum_{t=1}^T \underbrace{\boldsymbol{\mathsf{1}}^\top \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{Y}}_t}}_{(1, B) \,\times\, (B, q)} \end{aligned} \end{split}\]
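These two formulas can be checked numerically. The following is a minimal NumPy sketch, not the chapter's own code: it assumes \(f = \tanh\), a zero initial state, small hypothetical sizes, and a hypothetical linear loss \(\mathcal{L} = \sum_t \sum G_t \odot \boldsymbol{\mathsf{Y}}_t\) with fixed random \(G_t\), so that \(\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{Y}}_t} = G_t\) exactly. Since \(\boldsymbol{\mathsf{H}}_t\) does not depend on \(\boldsymbol{\mathsf{V}}\) or \(\boldsymbol{\mathsf{c}}\), central finite differences should agree with \(\sum_t \boldsymbol{\mathsf{H}}_t^\top G_t\) and \(\sum_t \boldsymbol{\mathsf{1}}^\top G_t\) up to rounding.

```python
import numpy as np

rng = np.random.default_rng(0)
T, B, d, h, q = 4, 3, 5, 6, 2

# Hypothetical sizes and random parameters; f = tanh is an assumption of this sketch
X = rng.normal(size=(T, B, d))
U = 0.5 * rng.normal(size=(d, h))
W = 0.5 * rng.normal(size=(h, h))
b = rng.normal(size=(1, h))
V = rng.normal(size=(h, q))
c = rng.normal(size=(1, q))
G = rng.normal(size=(T, B, q))  # linear loss L = sum(G * Y), hence dL/dY_t = G[t]


def forward(V, c):
    """Unrolled RNN with zero initial state; returns H (T, B, h) and Y (T, B, q)."""
    H_prev = np.zeros((B, h))
    Hs, Ys = [], []
    for t in range(T):
        H_prev = np.tanh(X[t] @ U + H_prev @ W + b)
        Hs.append(H_prev)
        Ys.append(H_prev @ V + c)
    return np.stack(Hs), np.stack(Ys)


def loss(V, c):
    return np.sum(G * forward(V, c)[1])


# Formulas from the text: sums over time of H_t^T dL/dY_t and 1^T dL/dY_t
H, _ = forward(V, c)
dV = sum(H[t].T @ G[t] for t in range(T))           # (h, q)
dc = sum(np.ones((1, B)) @ G[t] for t in range(T))  # (1, q)


def central_diff(f, P, eps=1e-6):
    """Central finite-difference gradient of the scalar f(P) w.r.t. the array P."""
    grad = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        E = np.zeros_like(P)
        E[idx] = eps
        grad[idx] = (f(P + E) - f(P - E)) / (2 * eps)
    return grad


dV_fd = central_diff(lambda P: loss(P, c), V)
dc_fd = central_diff(lambda P: loss(V, P), c)
print(np.abs(dV - dV_fd).max(), np.abs(dc - dc_fd).max())
```

A linear loss is chosen only to isolate the formulas for \(\boldsymbol{\mathsf{V}}\) and \(\boldsymbol{\mathsf{c}}\) from the loss derivative; with cross-entropy, \(G_t\) would simply be replaced by the softmax-minus-one-hot gradient.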

Next, we calculate the gradient flowing to \(\boldsymbol{\mathsf{H}}_t\), which will be our gateway to computing the gradients of \(\boldsymbol{\mathsf{W}}\), \(\boldsymbol{\mathsf{U}}\), and \(\boldsymbol{\mathsf{b}}\), and finally \(\boldsymbol{\mathsf{X}}_t.\) Note that \(\boldsymbol{\mathsf{H}}_t\) affects not only \(\boldsymbol{\mathsf{Y}}_t\) but also future outputs \(\boldsymbol{\mathsf{Y}}_{t^\prime}\) via \(\boldsymbol{\mathsf{H}}_{t^\prime}\) for \(t^\prime > t.\) In terms of direct dependence, however, the only nodes that immediately depend on \(\boldsymbol{\mathsf{H}}_t\) are \(\boldsymbol{\mathsf{Y}}_t\) and \(\boldsymbol{\mathsf{H}}_{t+1}\) (Fig. 59). Let \(\boldsymbol{\mathsf{Z}}_{t+1} = \boldsymbol{\mathsf{X}}_{t+1} \boldsymbol{\mathsf{U}} + \boldsymbol{\mathsf{H}}_{t} \boldsymbol{\mathsf{W}} + \boldsymbol{\mathsf{b}}.\) Then,

\[\begin{split} \begin{aligned} \underbrace{\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t}}_{(B, h)} &= \text{prod}\left( \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{Y}}_t}, \frac{\partial\boldsymbol{\mathsf{Y}}_t}{\partial\boldsymbol{\mathsf{H}}_t} \right) + \text{prod}\left( \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_{t + 1}}, \frac{\partial\boldsymbol{\mathsf{H}}_{t + 1}}{\partial\boldsymbol{\mathsf{Z}}_{t+1}}, \frac{\partial\boldsymbol{\mathsf{Z}}_{t+1}}{\partial\boldsymbol{\mathsf{H}}_{t }} \right) \\ &= \underbrace{\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{Y}}_t}\, \boldsymbol{\mathsf{V}}^\top}_{(B, q)\,\times\,(q, h)} + \underbrace{ \left( \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_{t + 1}} \odot f^\prime(\boldsymbol{\mathsf{Z}}_{t+1}) \right) \boldsymbol{\mathsf{W}}^\top }_{((B, h)\, \cdot \, (B, h)) \, \times \, (h, h)} \end{aligned} \end{split}\]

To make sense of this, recall that \(\boldsymbol{\mathsf{V}}\) and \(\boldsymbol{\mathsf{W}}\) multiply \(\boldsymbol{\mathsf{H}}_t\) from the right (\(\boldsymbol{\mathsf{H}}_t \boldsymbol{\mathsf{V}}\) and \(\boldsymbol{\mathsf{H}}_t \boldsymbol{\mathsf{W}}\)). Hence the upstream gradient is multiplied on the right by their transposes: for example, \(\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{Y}}_t} \boldsymbol{\mathsf{V}}^\top\) has shape \((B, q) \times (q, h) = (B, h)\), which sums over the output dimension and returns a tensor with the same shape as \(\boldsymbol{\mathsf{H}}_t.\) The same shape bookkeeping confirms the order of the factors in the second term.

The expression above is recursive, so we can unroll it into a closed-form expression involving only terms at time steps \(t, t+1, \ldots, T.\) For tractability, let's assume there is no nonlinearity, i.e., \(f = \text{Id},\) so that \(f^\prime(\boldsymbol{\mathsf{Z}}_{t + 1}) = \mathbf{1}_{(B, h)}\). Then, we can write:

\[ \begin{aligned} a_t = \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t}, \quad b_t = \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{Y}}_t}\, \boldsymbol{\mathsf{V}}^\top, \quad c_t = c = \boldsymbol{\mathsf{W}}^\top \end{aligned} \]

with \(a_{T+1} = 0\) and \(a_T = b_T.\) Hence,

\[\begin{split} \begin{aligned} a_t &= b_t + a_{t+1} c_t \\ &= b_t + (b_{t + 1} + a_{t + 2} c_{t + 1}) c_t \\ &= b_t + (b_{t + 1} + (b_{t + 2} + a_{t + 3} c_{t + 2}) c_{t + 1}) c_t \\ &= b_t + b_{t + 1} c_t + b_{t + 2}c_{t + 1}c_t + a_{t + 3} c_{t + 2}c_{t + 1}c_t \\ &\vdots \\ &= \sum_{\kappa = 0}^{T - t} b_{t + \kappa}\, c^{\kappa}. \end{aligned} \end{split}\]
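The unrolling can be sanity-checked numerically. In this NumPy sketch (with hypothetical random values for \(\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{Y}}_t}\), \(\boldsymbol{\mathsf{V}}\), and \(\boldsymbol{\mathsf{W}}\), and \(f = \text{Id}\) as in the text), we run the recursion \(a_t = b_t + a_{t+1} c\) backwards from \(a_{T+1} = 0\) and compare it with the closed form \(\sum_{\kappa} b_{t+\kappa} c^{\kappa}\).

```python
import numpy as np

rng = np.random.default_rng(1)
T, B, h, q = 6, 3, 4, 2

# Hypothetical upstream gradients dL/dY_t and weights; f = Id as in the text
dY = rng.normal(size=(T, B, q))
V = rng.normal(size=(h, q))
W = 0.5 * rng.normal(size=(h, h))

b_terms = [dY[t] @ V.T for t in range(T)]  # b_t = dL/dY_t V^T, shape (B, h)
c = W.T                                    # c_t = c = W^T

# Recursion a_t = b_t + a_{t+1} c, run backwards from a_{T+1} = 0
# (Python indices are 0-based, so time step t of the text is index t - 1)
a_rec = [None] * T
a_next = np.zeros((B, h))
for t in range(T - 1, -1, -1):
    a_rec[t] = b_terms[t] + a_next @ c
    a_next = a_rec[t]


def closed_form(t):
    """a_t = sum_{kappa=0}^{T-t} b_{t+kappa} c^kappa, written with 0-based t."""
    return sum(b_terms[t + k] @ np.linalg.matrix_power(c, k) for k in range(T - t))


max_err = max(np.abs(a_rec[t] - closed_form(t)).max() for t in range(T))
print(f"max |recursion - closed form| = {max_err:.1e}")
```

With 0-based indexing the text's upper limit \(\kappa = T - t\) becomes `range(T - t)`, and `matrix_power(c, 0)` is the identity, which gives the \(\kappa = 0\) term \(b_t\).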

Thus,

\[ \boxed{ \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t} = \sum_{\kappa = 0}^{T - t} \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{Y}}_{t + \kappa}}\, \boldsymbol{\mathsf{V}}^\top \left(\boldsymbol{\mathsf{W}}^\top\right)^{\kappa}. } \tag{7} \]

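This formula is similar to that for gradient flow across the layers of a deep MLP, except that here the depth runs along the sequence length. The term with index \(\kappa\) corresponds to a path of length \(\kappa\) from the current time step \(t\) to time step \(t + \kappa\), for \(\kappa = 0, \ldots, T - t.\) Because each such term carries the factor \(\left(\boldsymbol{\mathsf{W}}^\top\right)^{\kappa}\), in this linear setting the contributions from distant time steps shrink geometrically when the largest singular value of \(\boldsymbol{\mathsf{W}}\) is less than 1 (vanishing gradients), and can grow geometrically when \(\boldsymbol{\mathsf{W}}\) has an eigenvalue of magnitude greater than 1 (exploding gradients). Finally, observe that \(\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t}\) collects contributions only from the outputs at time steps \(t, \ldots, T\), never from earlier ones: the state at time \(t\) affects the loss only through the present and the future, which is the notion of causality we expected in RNNs.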



Fig. 60 Gradient transformation graph to get \(\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t}\) at time step \(t\) with increasing path length \(\kappa.\) Each edge is modulated by \(f^\prime\) and \(\boldsymbol{\mathsf{W}}^\top.\)
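The effect of the factor \(\left(\boldsymbol{\mathsf{W}}^\top\right)^{\kappa}\) can be seen in isolation. In this NumPy sketch we assume \(\boldsymbol{\mathsf{W}} = s\,Q\) with \(Q\) a random orthogonal matrix (a choice made so the answer is exact): every singular value of \(\boldsymbol{\mathsf{W}}\) then equals \(s\), and \(\lVert (\boldsymbol{\mathsf{W}}^\top)^{\kappa} \rVert_2 = s^{\kappa}\). The helper `power_norms` is hypothetical, introduced only for this illustration.

```python
import numpy as np


def power_norms(s, kappas, h=8, seed=2):
    """Spectral norms of (W^T)^kappa for W = s * Q with Q a random orthogonal matrix."""
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.normal(size=(h, h)))  # orthogonal: all singular values are 1
    Wt = s * Q.T                                  # W^T, every singular value equals s
    return [np.linalg.norm(np.linalg.matrix_power(Wt, k), ord=2) for k in kappas]


kappas = (0, 10, 50)
for s in (0.9, 1.1):
    norms = power_norms(s, kappas)
    print(f"s = {s}: " + ", ".join(f"kappa={k}: {n:.3g}" for k, n in zip(kappas, norms)))
```

Since \((s Q^\top)^{\kappa} = s^{\kappa} (Q^\top)^{\kappa}\) and a power of an orthogonal matrix is orthogonal, the printed norms equal \(s^{\kappa}\) up to rounding: about 0.005 for \(s = 0.9\) and about 117 for \(s = 1.1\) at \(\kappa = 50\), even though \(s\) differs from 1 by only 10%.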

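Finally, let's calculate the remaining parameter gradients. Since \(\boldsymbol{\mathsf{U}}\), \(\boldsymbol{\mathsf{W}}\), and \(\boldsymbol{\mathsf{b}}\) enter only through \(\boldsymbol{\mathsf{Z}}_t = \boldsymbol{\mathsf{X}}_t \boldsymbol{\mathsf{U}} + \boldsymbol{\mathsf{H}}_{t-1} \boldsymbol{\mathsf{W}} + \boldsymbol{\mathsf{b}}\), each of their gradients passes through \(\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t} \odot f^\prime(\boldsymbol{\mathsf{Z}}_t)\):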

\[\begin{split} \begin{aligned} \underbrace{\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{U}}}}_{(d, h)} &= \sum_{t=1}^T \text{prod}\left( \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{H}}_t}, \frac{\partial \boldsymbol{\mathsf{H}}_t}{\partial \boldsymbol{\mathsf{Z}}_t}, \frac{\partial \boldsymbol{\mathsf{Z}}_t}{\partial \boldsymbol{\mathsf{U}}} \right) = \sum_{t=1}^T \underbrace{ \boldsymbol{\mathsf{X}}_{t}^\top \left( \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_{t}} \odot f^\prime(\boldsymbol{\mathsf{Z}}_{t}) \right) }_{(d, B) \,\times\, ((B, h) \, \cdot\, (B, h))} \\\\ \underbrace{\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{W}}}}_{(h, h)} &= \sum_{t=1}^T \text{prod}\left( \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{H}}_t}, \frac{\partial \boldsymbol{\mathsf{H}}_t}{\partial \boldsymbol{\mathsf{Z}}_t}, \frac{\partial \boldsymbol{\mathsf{Z}}_t}{\partial \boldsymbol{\mathsf{W}}} \right) = \sum_{t=1}^T \underbrace{\boldsymbol{\mathsf{H}}_{t-1}^\top \left( \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_{t}} \odot f^\prime(\boldsymbol{\mathsf{Z}}_{t}) \right)}_{(h, B) \,\times\, ((B, h) \, \cdot\, (B, h))} \\\\ \underbrace{\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{b}}}}_{(1, h)} &= \sum_{t=1}^T \text{prod}\left( \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{H}}_t}, \frac{\partial \boldsymbol{\mathsf{H}}_t}{\partial \boldsymbol{\mathsf{Z}}_t}, \frac{\partial \boldsymbol{\mathsf{Z}}_t}{\partial \boldsymbol{\mathsf{b}}} \right) = \sum_{t=1}^T \underbrace{\boldsymbol{\mathsf{1}}^\top \left( \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_{t}} \odot f^\prime(\boldsymbol{\mathsf{Z}}_{t}) \right) }_{(1, B) \,\times\, ((B, h) \, \cdot \, (B, h))}. \end{aligned} \end{split}\]

The gradient with respect to the inputs may also be needed, e.g., in deep (stacked) RNNs, where \(\boldsymbol{\mathsf{X}}_t\) is the output of the layer below:

\[ \underbrace{\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{X}}_t}}_{(B, d)} = \text{prod}\left( \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{H}}_t}, \frac{\partial \boldsymbol{\mathsf{H}}_t}{\partial \boldsymbol{\mathsf{Z}}_t}, \frac{\partial \boldsymbol{\mathsf{Z}}_t}{\partial \boldsymbol{\mathsf{X}}_t} \right) = \underbrace{ \left( \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_{t}} \odot f^\prime(\boldsymbol{\mathsf{Z}}_{t}) \right) \boldsymbol{\mathsf{U}}^\top }_{((B, h) \, \cdot \, (B, h)) \, \times \, (h, d)} \]

Hence, since every gradient above is computed from \(\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t}\), this is the key quantity that affects numerical stability, cf. (7).
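The gradient formulas for \(\boldsymbol{\mathsf{U}}\), \(\boldsymbol{\mathsf{W}}\), \(\boldsymbol{\mathsf{b}}\), and \(\boldsymbol{\mathsf{X}}_t\) can also be checked without autograd. The following NumPy sketch is independent of the chapter's `LanguageModel` code: it implements the backward recursion for \(\frac{\partial\mathcal{L}}{\partial\boldsymbol{\mathsf{H}}_t}\) with \(f = \tanh\) and a zero initial state (both assumptions here), uses a hypothetical linear loss with fixed random weights \(G_t\) so that \(\frac{\partial \mathcal{L}}{\partial \boldsymbol{\mathsf{Y}}_t} = G_t\), and compares every gradient against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(3)
T, B, d, h, q = 4, 2, 3, 5, 2

# Hypothetical small problem; f = tanh and a zero initial state are assumptions
params = {
    "X": rng.normal(size=(T, B, d)),
    "U": 0.5 * rng.normal(size=(d, h)),
    "W": 0.5 * rng.normal(size=(h, h)),
    "b": rng.normal(size=(1, h)),
}
V = rng.normal(size=(h, q))
G = rng.normal(size=(T, B, q))  # linear loss L = sum(G * Y), hence dL/dY_t = G[t]


def forward(X, U, W, b):
    """Hidden states H (T, B, h) of the unrolled tanh RNN."""
    H_prev, Hs = np.zeros((B, h)), []
    for t in range(T):
        H_prev = np.tanh(X[t] @ U + H_prev @ W + b)
        Hs.append(H_prev)
    return np.stack(Hs)


def loss(X, U, W, b):
    # The output bias c is omitted: it does not change the gradients checked here
    return np.sum(G * (forward(X, U, W, b) @ V))


def bptt(X, U, W, b):
    """Manual gradients following the formulas in the text."""
    H = forward(X, U, W, b)
    J = 1 - H**2                                  # f'(Z_t) = 1 - H_t^2 for tanh
    dH = np.zeros_like(H)
    dH[T - 1] = G[T - 1] @ V.T                    # last step: no future state
    for t in range(T - 2, -1, -1):
        dH[t] = G[t] @ V.T + (dH[t + 1] * J[t + 1]) @ W.T
    dZ = dH * J                                   # dL/dZ_t
    H_prev = np.concatenate([np.zeros((1, B, h)), H[:-1]])  # H_{t-1}, zero at the first step
    return {
        "X": dZ @ U.T,                              # (T, B, d)
        "U": np.einsum("tbd,tbh->dh", X, dZ),       # (d, h)
        "W": np.einsum("tbi,tbj->ij", H_prev, dZ),  # (h, h)
        "b": dZ.sum(axis=(0, 1)).reshape(1, h),     # (1, h)
    }


def finite_diff(name, eps=1e-6):
    """Central finite differences of the loss w.r.t. params[name]."""
    grad = np.zeros_like(params[name])
    for idx in np.ndindex(grad.shape):
        plus = {k: v.copy() for k, v in params.items()}
        minus = {k: v.copy() for k, v in params.items()}
        plus[name][idx] += eps
        minus[name][idx] -= eps
        grad[idx] = (loss(**plus) - loss(**minus)) / (2 * eps)
    return grad


grads = bptt(**params)
for name in params:
    print(name, np.abs(grads[name] - finite_diff(name)).max())
```

Finite differences are slow (two forward passes per parameter entry), so this kind of check is only practical for tiny problems like this one; the PyTorch verification below compares against autograd instead.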


Manual verification

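import torch
import torch.nn.functional as F

# B: batch size, T: sequence length, V: vocabulary size, h: hidden size.
# Here V is the vocabulary size; the matrix V of the equations is model.linear.weight.T.
B, T, V, h = 32, 5, 30, 64

# random input tokens and targets; inputs are one-hot encoded
O = torch.randint(low=0, high=V, size=(T, B))    # targets (T, B)
x = torch.randint(low=0, high=V, size=(T, B))    # inputs  (T, B)
X = F.one_hot(x, num_classes=V).float()          # (T, B, V)
X.requires_grad = True                           # X is a leaf, so backward() fills X.grad

# LanguageModel and RNN are assumed to be the classes defined in an earlier
# section (not shown here); the cell is assumed to use f = tanh.
model = LanguageModel(RNN)(V, h, V)

W  = model.cell.W                                # (h, h)
U  = model.cell.U                                # (V, h)
b  = model.cell.b                                # (h,)
Vt = model.linear.weight                         # (V, h), i.e. V^T in the equations
c  = model.linear.bias                           # (V,)
Y  = model(X)                                    # logits (T, B, V)
H  = model.cell(X)[0]                            # hidden states (T, B, h)
J  = 1 - H * H                                   # f'(Z_t) = 1 - H_t^2 for tanh, (T, B, h)

# backprop: Y is not a leaf, so its gradient must be retained explicitly
Y.retain_grad()
loss = F.cross_entropy(Y.transpose(1, 2), O)     # logits as (T, V, B), targets (T, B)
loss.backward(retain_graph=True)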

Smoke test:

assert ((H @ Vt.T + c) - Y).abs().max() == 0.0

Calculating the gradients by hand:

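# dL/dY_t from autograd; dL/dH_t from the backward recursion over t = T, ..., 1
dY = Y.grad
dH = [None] * T
dH[T - 1] = dY[T - 1] @ Vt                       # last step: no future state
for t in range(T - 2, -1, -1):
    dH[t] = dY[t] @ Vt + (dH[t + 1] * J[t + 1]) @ W.T

dH = torch.stack(dH)                             # (T, B, h)
dZ = dH * J                                      # dL/dZ_t = dL/dH_t * f'(Z_t)
dc = torch.einsum('tbj -> j', dY)
dV = torch.einsum('tbh, tbv -> hv', H, dY)       # (h, V), the transpose of Vt.grad
dU = torch.einsum('tbv, tbh -> vh', X, dZ)
db = torch.einsum('tbh -> h', dZ)
dX = torch.einsum('tbh, vh -> tbv', dZ, U)
# the sum starts at t = 1 because the initial state is assumed to be zero
dW = sum([H[t - 1].T @ dZ[t] for t in range(1, T)], torch.zeros((h, h)))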

Calculating absolute errors versus autograd:

def compare(name, dt, t):
    exact = torch.all(dt == t.grad).item()
    approx = torch.allclose(dt, t.grad, rtol=1e-5)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{name:<3s} | exact: {str(exact):5s} | approx: {str(approx):5s} | maxdiff: {maxdiff:.2e}')
    return approx

assert compare('dV', dV.T, Vt)
assert compare('dc', dc, c)
assert compare('dU', dU, U)
assert compare('dW', dW, W)
assert compare('db', db, b)
assert compare('dX', dX, X)
dV  | exact: False | approx: True  | maxdiff: 6.52e-09
dc  | exact: True  | approx: True  | maxdiff: 0.00e+00
dU  | exact: False | approx: True  | maxdiff: 9.31e-10
dW  | exact: False | approx: True  | maxdiff: 4.66e-10
db  | exact: False | approx: True  | maxdiff: 3.73e-09
dX  | exact: True  | approx: True  | maxdiff: 0.00e+00