Deep Recurrent Networks
To increase network capacity, we can make an RNN “vertically deep” at each time step (i.e. deep in the usual FFN sense). Recall that RNNs are already “horizontally deep”: early inputs influence the outputs and state at later time steps. Adding depth in the usual sense allows the network to learn higher-order state vectors \(\boldsymbol{\mathsf{H}}^{\ell}_t.\)
Fig. 63 Deep RNN architecture. Observe that it requires \(L\) state vectors at each step.
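For a vanilla recurrent cell, the stacked update can be written as (a standard formulation; the activation \(\phi\) and weight shapes are implied by the dimensions)

\[
\boldsymbol{\mathsf{H}}^{\ell}_t = \phi\!\left(\boldsymbol{\mathsf{H}}^{\ell-1}_t \boldsymbol{\mathsf{W}}^{\ell}_{xh} + \boldsymbol{\mathsf{H}}^{\ell}_{t-1} \boldsymbol{\mathsf{W}}^{\ell}_{hh} + \boldsymbol{\mathsf{b}}^{\ell}\right), \qquad \boldsymbol{\mathsf{H}}^{0}_t = \boldsymbol{\mathsf{X}}_t,
\]

so layer \(\ell\) at step \(t\) reads the state of layer \(\ell - 1\) at the same step and its own state from the previous step.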
Code implementation
from torch.utils.data import DataLoader, random_split
data, tokenizer = TimeMachine().build()
T = 30
BATCH_SIZE = 128
VOCAB_SIZE = tokenizer.vocab_size
dataset = SequenceDataset(data, seq_len=T, vocab_size=VOCAB_SIZE)
train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)  # shuffled: each validation step below samples a random batch
print("preds per epoch")
print("train:", f"{len(train_loader) * BATCH_SIZE * T: .3e}")
print("valid:", f"{len(valid_loader) * BATCH_SIZE * T: .3e}")
preds per epoch
train: 4.182e+06
valid: 1.048e+06
Deep RNN. Note that at each time step, we compute \(\boldsymbol{\mathsf{o}}_t^\ell, \boldsymbol{\mathsf{h}}_t^\ell = f(\boldsymbol{\mathsf{o}}_t^{\ell - 1}, \boldsymbol{\mathsf{h}}_{t-1}^\ell)\) where \(\boldsymbol{\mathsf{o}}_t^0 = \boldsymbol{\mathsf{x}}_t.\) This can be visualized as the output vectors moving depthwise, while the previous states are fed horizontally as in Fig. 63. Hence, we can compute all outputs at the same depth before pushing them to the next layer (i.e. fix \(\ell\) and iterate over \(t\)).
from functools import partial

class DeepRNN(RNNBase):
    def __init__(self,
        cell: Type[RNNBase],
        inputs_dim: int, hidden_dim: int,
        num_layers: int,  # (!)
        **kwargs,
    ):
        super().__init__(inputs_dim, hidden_dim)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        for l in range(num_layers):
            if l == 0:
                # first layer consumes the input vectors
                self.layers.append(cell(inputs_dim, hidden_dim, **kwargs))
            else:
                # deeper layers consume the outputs of the layer below
                self.layers.append(cell(hidden_dim, hidden_dim, **kwargs))

    def init_state(self, x):
        """Defer state init to each cell with state=None."""
        return [None] * self.num_layers

    def compute(self, x, state):
        # run each layer over the entire sequence before moving one level deeper
        # (i.e. fix the depth index and iterate over time inside each cell)
        out = x
        for l, cell in enumerate(self.layers):
            out, state[l] = cell(out, state[l])
        return out, state
Deep = lambda cell: partial(DeepRNN, cell)
Remark. You can swap out GRU for any recurrent cell (e.g. RNN, LSTM) that follows the same interface.
gru = Deep(GRU)(5, 2, num_layers=3)
print(gru)
DeepRNN(
  (layers): ModuleList(
    (0): GRU(
      (R): Linear(in_features=7, out_features=2, bias=True)
      (Z): Linear(in_features=7, out_features=2, bias=True)
      (G): Linear(in_features=7, out_features=2, bias=True)
    )
    (1-2): 2 x GRU(
      (R): Linear(in_features=4, out_features=2, bias=True)
      (Z): Linear(in_features=4, out_features=2, bias=True)
      (G): Linear(in_features=4, out_features=2, bias=True)
    )
  )
)
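The in_features of each gate follow from the custom GRU applying its gates to the concatenation \([\boldsymbol{\mathsf{x}}_t, \boldsymbol{\mathsf{h}}_{t-1}]\) (inferred from the printed shapes):

\[
\ell = 0:\ \ 5 + 2 = 7, \qquad \ell > 0:\ \ 2 + 2 = 4,
\]

i.e. inputs_dim + hidden_dim for the first layer and hidden_dim + hidden_dim for the rest.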
Shape test:
x = torch.randn(10, 32, 5)
outs, state = gru(x)
assert len(state) == 3
assert outs.shape == (10, 32, 2)
assert state[0].shape == (32, 2)
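To sanity check the remark above, the same factory should accept a different cell. A minimal sketch, assuming the LSTM cell from the earlier section shares the RNNBase interface (state handling is deferred to each cell, so init_state needs no change):

lstm = Deep(LSTM)(5, 2, num_layers=3)       # assumed: LSTM(inputs_dim, hidden_dim)
outs, state = lstm(torch.randn(10, 32, 5))
assert outs.shape == (10, 32, 2)            # same output shape as the GRU stack
assert len(state) == 3                      # one (cell-specific) state per layer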
Correctness for more than one layer, compared against nn.GRU:
gru = Deep(GRU)(5, 2, num_layers=3)
gru_torch = nn.GRU(5, 2, num_layers=3)

# set all weights to 1 and all biases to 0 so both implementations compute the same map
for net in [gru, gru_torch]:
    for name, p in net.named_parameters():
        if "bias" in name:
            p.data.fill_(0.0)
        else:
            p.data.fill_(1.0)

error = torch.max(torch.abs(gru(x)[0] - gru_torch(x)[0]))
print(error)
assert error < 1e-5
tensor(6.2585e-07, grad_fn=<MaxBackward1>)
Recovers the base case (i.e. num_layers=1):
gru = Deep(GRU)(5, 2, num_layers=1)
gru_base = GRU(5, 2)

for net in [gru, gru_base]:
    for name, p in net.named_parameters():
        if "bias" in name:
            p.data.fill_(0.0)
        else:
            p.data.fill_(1.0)

error = torch.max(torch.abs(gru_base(x)[0] - gru(x)[0]))
print(error)
assert error < 1e-6
tensor(0., grad_fn=<MaxBackward1>)
Model training
Common RNN layer widths are in the range [64, 2056] while depth is in [1, 8].
VOCAB_SIZE = tokenizer.vocab_size
model = LanguageModel(Deep(GRU))(VOCAB_SIZE, 64, vocab_size=VOCAB_SIZE, num_layers=3)
model
RNNLanguageModel(
  (cell): DeepRNN(
    (layers): ModuleList(
      (0): GRU(
        (R): Linear(in_features=92, out_features=64, bias=True)
        (Z): Linear(in_features=92, out_features=64, bias=True)
        (G): Linear(in_features=92, out_features=64, bias=True)
      )
      (1-2): 2 x GRU(
        (R): Linear(in_features=128, out_features=64, bias=True)
        (Z): Linear(in_features=128, out_features=64, bias=True)
        (G): Linear(in_features=128, out_features=64, bias=True)
      )
    )
  )
  (linear): Linear(in_features=64, out_features=28, bias=True)
)
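As a quick sanity check (not in the original text), the trainable parameter count can be read off the printed shapes, assuming the repr above lists every parameter:

# 3 gate Linears per GRU layer, plus the output head:
# 3*(92*64 + 64) + 2*3*(128*64 + 64) + (64*28 + 28) = 69,212
sum(p.numel() for p in model.parameters() if p.requires_grad)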
from tqdm.notebook import tqdm

DEVICE = "cpu"
LR = 0.01
EPOCHS = 5
MAX_NORM = 1.0

model.to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr=LR)

train_losses = []
valid_losses = []
for e in tqdm(range(EPOCHS)):
    for t, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        x, y = x.to(DEVICE), y.to(DEVICE)
        loss = train_step(model, optim, x, y, MAX_NORM)
        train_losses.append(loss)
        if t % 5 == 0:
            # evaluate on one randomly sampled validation batch every 5 steps
            xv, yv = next(iter(valid_loader))
            xv, yv = xv.to(DEVICE), yv.to(DEVICE)
            valid_losses.append(valid_step(model, xv, yv))
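Here train_step and valid_step come from the earlier RNN chapters. As a minimal sketch (illustrative only, not the author's exact code), the MAX_NORM argument is assumed to be used for gradient clipping roughly as follows:

import torch.nn.functional as F

def train_step_sketch(model, optim, x, y, max_norm):
    # assumed: model(x) returns (logits, state) with logits of shape (T, B, vocab_size),
    # and y holds the target token indices
    model.train()
    logits, _ = model(x)
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))
    optim.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # clip gradient norm
    optim.step()
    return loss.item()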
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
backend_inline.set_matplotlib_formats("svg")
plt.figure(figsize=(10, 4))
plt.plot(train_losses, label="train")
plt.plot(np.array(range(1, len(valid_losses) + 1)) * 5, valid_losses, label="valid")
plt.grid(linestyle="dotted", alpha=0.6)
plt.ylabel("loss")
plt.xlabel("step")
plt.legend();
Best performance in this chapter (mean of the last 50 validation losses):
np.array(valid_losses[-50:]).mean()
1.2173951292037963
Text generation
textgen = TextGenerator(model, tokenizer, device="cpu")
s = [textgen.predict("thank y", num_preds=2, temperature=0.4) for i in range(20)]
(np.array(s) == "thank you").mean()
0.05
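Temperature controls how peaked the sampling distribution is: values below 1 sharpen it (more conservative, repetitive text), while values above 1 flatten it (more random text). A minimal sketch of temperature sampling, assuming this is roughly what TextGenerator.predict does per character (the helper name is hypothetical):

import torch.nn.functional as F

def sample_next(logits: torch.Tensor, temperature: float) -> int:
    """Sample one token index from temperature-scaled logits."""
    probs = F.softmax(logits / temperature, dim=-1)  # t < 1 sharpens, t > 1 flattens
    return torch.multinomial(probs, num_samples=1).item()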
warmup = "mr williams i underst"
text = []
temperature = []
for i in range(1, 6):
    t = 0.20 * i
    s = textgen.predict(warmup, num_preds=100, temperature=t)
    text.append(s)
    temperature.append(t)
import pandas as pd
from IPython.display import display
pd.set_option("display.max_colwidth", None)
df = pd.DataFrame({"temp": [f"{t:.1f}" for t in temperature], "text": text})
df = df.style.set_properties(**{"text-align": "left"})
display(df)
| | temp | text |
|---|---|---|
| 0 | 0.2 | mr williams i understood and i had seen of the black and the morlocks and the time traveller and the black still assured |
| 1 | 0.4 | mr williams i understood as the sun in the palace of man had a shadows of an inces and doubted the lawn in the laboratory |
| 2 | 0.6 | mr williams i understood the touch and only the but the cares was one hundred and been limple of more and the days and sh |
| 3 | 0.8 | mr williams i understood me the mere that them look the place melfreceimed this out then i dare very little thing their p |
| 4 | 1.0 | mr williams i understankant creet complete fallen like came to learn myself it seemishing the work uild its paston with i |
This is surprisingly good. ꉂ૮(°□°’˶)ა