Gated Recurrent Unit (GRU)
Recall that the hidden state of the LSTM cell is essentially the internal cell state filtered based on the current input (i.e. by the output gate \(\boldsymbol{\Gamma}^o_t\)). Modifications such as peephole connections have been proposed to expose the pure internal cell state to the gate computations. A further simplification is the GRU [CvMG+14], which combines the feedback mechanisms into a single, appropriately gated hidden state vector. This significantly reduces the number of parameters compared to the LSTM while achieving similar performance.
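To make the parameter savings concrete, a quick check using PyTorch's built-in recurrent modules (the dimensions 10 and 20 below are purely illustrative) shows the expected 3:4 ratio, since the GRU has three gated blocks where the LSTM has four:

import torch.nn as nn

d, h = 10, 20   # illustrative input and hidden widths
n_gru  = sum(p.numel() for p in nn.GRU(d, h).parameters())
n_lstm = sum(p.numel() for p in nn.LSTM(d, h).parameters())
print(n_gru, n_lstm, n_gru / n_lstm)   # 1920 2560 0.75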
Gating mechanism
Similar to an LSTM cell, GRU uses gates to control update and reset of memory over time:
Gate | Symbol | Controls
---|---|---
Reset | \(\boldsymbol{\Gamma}^r\) | How much of the previous state is incorporated into the new state
Update | \(\boldsymbol{\Gamma}^z\) | How much of the old state is retained, versus the candidate state
This is similar to the LSTM cell with the input gate corresponding to the reset gate and the forget gate corresponding to the update gate. There is no output gate; the current state is calculated as a weighted average of the previous state and the candidate state.
GRU equations
GRU computation is defined as follows. Let \(h\) be the width of the GRU and \(B\) be the batch size. The gates are computed from the current input and the previous hidden state:

\[
\begin{aligned}
\boldsymbol{\Gamma}^r_t &= \sigma\left(\left[\boldsymbol{\mathsf{X}}_t,\, \boldsymbol{\mathsf{H}}_{t-1}\right] \boldsymbol{\mathsf{W}}^{r} + \boldsymbol{\mathsf{b}}^{r}\right) \\
\boldsymbol{\Gamma}^z_t &= \sigma\left(\left[\boldsymbol{\mathsf{X}}_t,\, \boldsymbol{\mathsf{H}}_{t-1}\right] \boldsymbol{\mathsf{W}}^{z} + \boldsymbol{\mathsf{b}}^{z}\right)
\end{aligned}
\]

All entries of the gates lie in \((0, 1) \subset \mathbb{R}\), so each gate can be interpreted as a \(B \times h\) smooth switch. The hidden state \(\boldsymbol{\mathsf{H}}_t\) with shape \(B \times h\) is updated as

\[
\begin{aligned}
\tilde{\boldsymbol{\mathsf{H}}}_t &= \tanh\left(\left[\boldsymbol{\mathsf{X}}_t,\, \boldsymbol{\Gamma}^r_t \odot \boldsymbol{\mathsf{H}}_{t-1}\right] \boldsymbol{\mathsf{W}}^{g} + \boldsymbol{\mathsf{b}}^{g}\right) \\
\boldsymbol{\mathsf{H}}_t &= \boldsymbol{\Gamma}^z_t \odot \boldsymbol{\mathsf{H}}_{t-1} + \left(1 - \boldsymbol{\Gamma}^z_t\right) \odot \tilde{\boldsymbol{\mathsf{H}}}_t
\end{aligned}
\]

where \(\tilde{\boldsymbol{\mathsf{H}}}_t\) is called the candidate hidden state and \([\cdot, \cdot]\) denotes concatenation along the feature dimension. Applying \(\tanh\) ensures that elements of \(\tilde{\boldsymbol{\mathsf{H}}}_t\) are in \((-1, 1)\), and since \(\boldsymbol{\mathsf{H}}_t\) is a convex combination of \(\boldsymbol{\mathsf{H}}_{t-1}\) and \(\tilde{\boldsymbol{\mathsf{H}}}_t\), its elements stay in \((-1, 1)\) as well. The computation is illustrated in Fig. 62.
First, let us look at the reset gate. When \({\boldsymbol{\Gamma}^r_t} = 1\), the entire previous hidden state is used to calculate the candidate hidden state. On the other hand, if \({\boldsymbol{\Gamma}^r_t} = 0\), the hidden state is reset and the candidate state depends only on the current input. However, the update gate can still ignore the candidate state with \({\boldsymbol{\Gamma}^z_t} = 1\), so that \(\boldsymbol{\mathsf{h}}_t \approx \boldsymbol{\mathsf{h}}_{t - 1}\) and past inputs can still affect future outputs. Similarly, the GRU can accrue information across many time steps this way by keeping the reset and update gates open. Once the update gate closes, the accumulated information in the candidate state takes effect. In general, this is how the unit handles long-term dependencies. The gradient with respect to \(\boldsymbol{\mathsf{h}}_{t-1}\) is obtained by following the nodes that depend on it. There are four such paths, but the direct path through \(\frac{\partial{\boldsymbol{\mathsf{h}}_t}}{\partial{\boldsymbol{\mathsf{h}}_{t-1}}}\) only involves elementwise scaling by \({\boldsymbol{\Gamma}^z_t}\) rather than multiplication by a weight matrix, which helps gradients propagate across many steps.
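As a minimal illustration of this last point (using made-up toy tensors, with the gate and candidate held fixed), the gradient of \(\boldsymbol{\mathsf{h}}_t\) with respect to \(\boldsymbol{\mathsf{h}}_{t-1}\) along the direct path is exactly the update gate applied elementwise:

import torch

h_prev = torch.randn(4, requires_grad=True)
z      = torch.tensor([0.9, 0.9, 0.1, 1.0])   # update gate (held fixed for illustration)
h_cand = torch.randn(4)                        # candidate state (held fixed for illustration)

h_t = z * h_prev + (1 - z) * h_cand            # GRU state update
h_t.sum().backward()
print(h_prev.grad)                             # equals z: elementwise scaling, no matrix multiplication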
Code implementation
class GRU(RNNBase):
    def __init__(self, inputs_dim: int, hidden_dim: int):
        super().__init__(inputs_dim, hidden_dim)
        self.hidden_dim = hidden_dim
        self.inputs_dim = inputs_dim
        self.R = nn.Linear(inputs_dim + hidden_dim, hidden_dim)  # reset gate
        self.Z = nn.Linear(inputs_dim + hidden_dim, hidden_dim)  # update gate
        self.G = nn.Linear(inputs_dim + hidden_dim, hidden_dim)  # candidate state

    def init_state(self, x):
        B = x.shape[1]
        return torch.zeros(B, self.hidden_dim, device=x.device)

    def _step(self, x_t, state):
        h = state
        x_gate = torch.cat([x_t, h], dim=1)
        r = torch.sigmoid(self.R(x_gate))                        # reset gate
        z = torch.sigmoid(self.Z(x_gate))                        # update gate
        g = torch.tanh(self.G(torch.cat([x_t, r * h], dim=1)))   # candidate hidden state
        h = z * h + (1 - z) * g                                  # convex combination of old and candidate
        return h, h

    def compute(self, x, state):
        T = x.shape[0]
        outs = []
        for t in range(T):
            out, state = self._step(x[t], state)
            outs.append(out)
        return torch.stack(outs), state
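Note that applying a single nn.Linear to the concatenated vector \([\boldsymbol{\mathsf{x}}_t, \boldsymbol{\mathsf{h}}]\), as done above, is equivalent to the usual formulation with separate input-to-hidden and hidden-to-hidden weight matrices. A quick check with illustrative dimensions:

import torch
import torch.nn as nn

d, h, B = 10, 20, 32
lin = nn.Linear(d + h, h)
x_t, h_t = torch.randn(B, d), torch.randn(B, h)

W_x, W_h = lin.weight[:, :d], lin.weight[:, d:]          # split the weight matrix
out_cat   = lin(torch.cat([x_t, h_t], dim=1))            # single linear on [x, h]
out_split = x_t @ W_x.T + h_t @ W_h.T + lin.bias         # two weight matrices + bias
assert torch.allclose(out_cat, out_split, atol=1e-5)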
Shapes test:
B, T, d, h = 32, 5, 10, 20
x = torch.randn(T, B, d)
gru = GRU(d, h)
outs, H = gru(x) # same with PyTorch:
# https://pytorch.org/docs/stable/generated/torch.nn.GRU.html
assert outs.shape == (T, B, h)
assert H.shape == (B, h)
assert all(H[0] == outs[-1][0])
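For reference, PyTorch's built-in nn.GRU (linked above) follows the same (T, B, ·) layout with its default batch_first=False; continuing the test above, the only difference in the returned shapes is a leading num_layers axis on the final state:

ref = nn.GRU(d, h)                  # built-in module, default batch_first=False
ref_outs, ref_H = ref(x)
assert ref_outs.shape == (T, B, h)
assert ref_H.shape == (1, B, h)     # (num_layers * num_directions, B, h)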
Model training
from torch.utils.data import random_split
data, tokenizer = TimeMachine().build()
T = 30
BATCH_SIZE = 128
VOCAB_SIZE = tokenizer.vocab_size
dataset = SequenceDataset(data, seq_len=T, vocab_size=VOCAB_SIZE)
train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn) # also sampled
print("preds per epoch")
print("train:", f"{len(train_loader) * BATCH_SIZE * T: .3e}")
print("valid:", f"{len(valid_loader) * BATCH_SIZE * T: .3e}")
preds per epoch
train: 4.182e+06
valid: 1.048e+06
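These counts are simply len(loader) × BATCH_SIZE × T: the training split here has roughly 1089 batches, and 1089 × 128 × 30 ≈ 4.18 × 10⁶ next-character predictions per epoch.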
from tqdm.notebook import tqdm

DEVICE = "cpu"
LR = 0.01
EPOCHS = 5
MAX_NORM = 1.0

model = LanguageModel(GRU)(VOCAB_SIZE, 64, VOCAB_SIZE)
model.to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr=LR)

train_losses = []
valid_losses = []
for e in tqdm(range(EPOCHS)):
    for t, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        x, y = x.to(DEVICE), y.to(DEVICE)
        loss = train_step(model, optim, x, y, MAX_NORM)
        train_losses.append(loss)
        if t % 5 == 0:
            xv, yv = next(iter(valid_loader))
            xv, yv = xv.to(DEVICE), yv.to(DEVICE)
            valid_losses.append(valid_step(model, xv, yv))
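The helpers train_step and valid_step are defined elsewhere in these notes; as a rough sketch of what train_step is assumed to do here (the model's output shape and return value are assumptions), it performs one optimizer update with gradient-norm clipping, which is where MAX_NORM enters:

import torch.nn as nn
import torch.nn.functional as F

def train_step(model, optim, x, y, max_norm):
    """One optimizer step with gradient clipping (sketch, not the book's exact code)."""
    optim.zero_grad()
    logits = model(x)                                  # assumed shape: (T, B, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optim.step()
    return loss.item()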
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
backend_inline.set_matplotlib_formats("svg")
plt.figure(figsize=(10, 4))
plt.plot(train_losses, label="train")
plt.plot(np.array(range(1, len(valid_losses) + 1)) * 5, valid_losses, label="valid")
plt.grid(linestyle="dotted", alpha=0.6)
plt.ylabel("loss")
plt.xlabel("step")
plt.legend();
np.array(valid_losses[-50:]).mean()
1.3439158058166505
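Assuming the reported loss is the mean cross-entropy in nats per character, this corresponds to a per-character perplexity of about exp(1.344) ≈ 3.8, i.e. the model is roughly as uncertain as a uniform choice over four characters:

np.exp(np.array(valid_losses[-50:]).mean())   # ≈ 3.83, assuming mean cross-entropy in nats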
Text generation
textgen = TextGenerator(model, tokenizer, device="cpu")
s = [textgen.predict("thank y", num_preds=2, temperature=0.4) for i in range(20)]
(np.array(s) == "thank you").mean()
0.95
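The temperature parameter controls how sharply the model's next-character distribution is peaked before sampling. A minimal sketch of the usual temperature scaling (the helper name below is hypothetical; the book's TextGenerator may differ in details):

import torch
import torch.nn.functional as F

def sample_next(logits: torch.Tensor, temperature: float) -> int:
    """Sample a character id from temperature-scaled logits (hypothetical helper)."""
    probs = F.softmax(logits / temperature, dim=-1)   # t < 1 sharpens, t > 1 flattens the distribution
    return torch.multinomial(probs, num_samples=1).item()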
warmup = "mr williams i underst"
text = []
temperature = []
for i in range(1, 6):
    t = 0.20 * i
    s = textgen.predict(warmup, num_preds=100, temperature=t)
    text.append(s)
    temperature.append(t)
import pandas as pd
from IPython.display import display
pd.set_option("display.max_colwidth", None)
df = pd.DataFrame({"temp": [f"{t:.1f}" for t in temperature], "text": text})
df = df.style.set_properties(**{"text-align": "left"})
display(df)
 | temp | text
---|---|---
0 | 0.2 | mr williams i understand the time traveller the time traveller the palace of the start of the time traveller the palace o
1 | 0.4 | mr williams i understould find the lamp to the little people were a passed the bright race of face the force and said the
2 | 0.6 | mr williams i understards me fiesh in the only to the sun seen and sure me out to the time traveller as i carks of the th
3 | 0.8 | mr williams i understood my diront of the rose they reed perhaps the ancessments was s the behind madzes to the time trav
4 | 1.0 | mr williams i understail his vate and they nights began to make what continuar when it feeckean there were too absender t