{ "cells": [ { "cell_type": "markdown", "id": "7a4912ac", "metadata": { "papermill": { "duration": 0.009569, "end_time": "2025-01-13T14:53:19.267785", "exception": false, "start_time": "2025-01-13T14:53:19.258216", "status": "completed" }, "tags": [] }, "source": [ "# RNN model training" ] }, { "cell_type": "markdown", "id": "78c572f5", "metadata": { "papermill": { "duration": 0.018123, "end_time": "2025-01-13T14:53:19.294596", "exception": false, "start_time": "2025-01-13T14:53:19.276473", "status": "completed" }, "tags": [] }, "source": [ "Each mini-batch contains $B \\times T$ prediction instances for training. \n", "Prediction is done at each step, with the state updated at each step as well.\n", "Below are logits for $t = 1, \\ldots, 30.$ At each step, the hidden state is updated before making the next prediction. Finally, the model is evaluated at every time step with varying-length inputs. State starts at zero, so it may be necessary to warm the model up." ] }, { "cell_type": "code", "execution_count": 1, "id": "6ef4fed3", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:19.301186Z", "iopub.status.busy": "2025-01-13T14:53:19.300939Z", "iopub.status.idle": "2025-01-13T14:53:20.275472Z", "shell.execute_reply": "2025-01-13T14:53:20.275161Z" }, "papermill": { "duration": 0.979657, "end_time": "2025-01-13T14:53:20.277299", "exception": false, "start_time": "2025-01-13T14:53:19.297642", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa74fbd2", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:20.282804Z", "iopub.status.busy": "2025-01-13T14:53:20.282595Z", "iopub.status.idle": "2025-01-13T14:53:20.308962Z", "shell.execute_reply": "2025-01-13T14:53:20.308627Z" }, "papermill": { "duration": 0.032474, "end_time": "2025-01-13T14:53:20.312456", "exception": false, "start_time": "2025-01-13T14:53:20.279982", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
def collate_fn(batch):\n",
" """Transforming the data to sequence-first format."""\n",
" x, y = zip(*batch)\n",
" x = torch.stack(x, 1) # (T, B, vocab_size)\n",
" y = torch.stack(y, 1) # (T, B)\n",
" return x, y\n",
"
import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def clip_grad_norm(model, max_norm: float):\n",
" """Calculate norm on concatenated params. Modify params in-place."""\n",
" params = [p for p in model.parameters() if p.requires_grad]\n",
" norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n",
" if norm > max_norm:\n",
" for p in params:\n",
" p.grad[:] *= max_norm / norm # [:] = shallow copy, in-place\n",
"
def train_step(model, optim, x, y, max_norm) -> float:\n",
"    """Run one optimization step on a mini-batch; return the batch loss."""\n",
"    target = y.transpose(0, 1)          # (T, B) -> (B, T)\n",
"    output = model(x).permute(1, 2, 0)  # (T, B, vocab_size) -> (B, vocab_size, T)\n",
"    loss = F.cross_entropy(output, target)\n",
"    loss.backward()\n",
"\n",
"    clip_grad_norm(model, max_norm=max_norm)\n",
"    optim.step()\n",
"    optim.zero_grad()\n",
"    return loss.item()\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def valid_step(model, x, y) -> float:\n",
" target = y.transpose(0, 1)\n",
" output = model(x).permute(1, 2, 0)\n",
" loss = F.cross_entropy(output, target)\n",
" return loss.item()\n",
"